// ferrolearn_tree/hist_gradient_boosting.rs

//! Histogram-based gradient boosting classifiers and regressors.
//!
//! This module provides [`HistGradientBoostingClassifier`] and
//! [`HistGradientBoostingRegressor`], which implement histogram-based gradient
//! boosting trees inspired by LightGBM / scikit-learn's HistGradientBoosting*.
//!
//! # Key design choices
//!
//! - **Feature binning**: continuous features are discretised into up to
//!   `max_bins` (default 255) bins using quantile-based bin edges. NaN values
//!   are assigned a dedicated bin.
//! - **Histogram-based split finding**: at each node, gradient/hessian sums are
//!   accumulated into per-bin histograms, making split finding O(n_bins) instead
//!   of O(n log n).
//! - **Subtraction trick**: the child histogram with more samples is computed by
//!   subtracting the smaller child's histogram from the parent, halving the work.
//! - **Missing value support**: NaN values are routed to whichever side yields
//!   the best split, enabling native handling of missing data.
//!
//! # Regression Losses
//!
//! - **`LeastSquares`**: mean squared error; gradient = `F(x) - y`, hessian = 1.
//! - **`LeastAbsoluteDeviation`**: gradient = `sign(F(x) - y)`, hessian = 1
//!   (a constant-hessian approximation).
//!
//! # Classification Loss
//!
//! - **`LogLoss`**: binary and multiclass logistic loss (one tree per class,
//!   coupled through softmax probabilities).
//!
//! # Examples
//!
//! ```
//! use ferrolearn_tree::HistGradientBoostingRegressor;
//! use ferrolearn_core::{Fit, Predict};
//! use ndarray::{array, Array1, Array2};
//!
//! let x = Array2::from_shape_vec((8, 1), vec![
//!     1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
//! ]).unwrap();
//! let y = array![1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0];
//!
//! let model = HistGradientBoostingRegressor::<f64>::new()
//!     .with_n_estimators(50)
//!     .with_learning_rate(0.1)
//!     .with_random_state(42);
//! let fitted = model.fit(&x, &y).unwrap();
//! let preds = fitted.predict(&x).unwrap();
//! assert_eq!(preds.len(), 8);
//! ```
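//!
//! A minimal classifier sketch (assuming, as with the regressor above, that
//! `HistGradientBoostingClassifier` is re-exported at the crate root; class
//! labels are `usize`):
//!
//! ```
//! use ferrolearn_tree::HistGradientBoostingClassifier;
//! use ferrolearn_core::Fit;
//! use ndarray::{array, Array2};
//!
//! // Two well-separated groups of one-dimensional points.
//! let x = Array2::from_shape_vec((6, 1), vec![
//!     1.0, 2.0, 3.0, 10.0, 11.0, 12.0,
//! ]).unwrap();
//! let y = array![0usize, 0, 0, 1, 1, 1];
//!
//! let model = HistGradientBoostingClassifier::<f64>::new()
//!     .with_n_estimators(20)
//!     .with_min_samples_leaf(1);
//! let _fitted = model.fit(&x, &y).unwrap();
//! ```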

use ferrolearn_core::error::FerroError;
use ferrolearn_core::introspection::{HasClasses, HasFeatureImportances};
use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
use ferrolearn_core::traits::{Fit, Predict};
use ndarray::{Array1, Array2};
use num_traits::{Float, FromPrimitive, ToPrimitive};
use serde::{Deserialize, Serialize};

// ---------------------------------------------------------------------------
// Loss enums
// ---------------------------------------------------------------------------

/// Loss function for histogram-based gradient boosting regression.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum HistRegressionLoss {
    /// Least squares (L2) loss.
    LeastSquares,
    /// Least absolute deviation (L1) loss.
    LeastAbsoluteDeviation,
}

/// Loss function for histogram-based gradient boosting classification.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum HistClassificationLoss {
    /// Log-loss (logistic / cross-entropy) for binary and multiclass.
    LogLoss,
}

// ---------------------------------------------------------------------------
// Internal: histogram tree node
// ---------------------------------------------------------------------------

/// A node in a histogram-based gradient boosting tree.
///
/// Unlike [`crate::decision_tree::Node`], thresholds are stored as bin indices
/// rather than raw feature values. Prediction at a split checks whether the
/// sample's bin index for the feature is `<= threshold_bin`.
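///
/// For example, a split with `threshold_bin = 3` sends samples whose binned
/// value falls in bins `0..=3` to the left child and later bins to the
/// right; NaN-binned samples follow `nan_goes_left`.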
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum HistNode<F> {
    /// An internal split node.
    Split {
        /// Feature index used for the split.
        feature: usize,
        /// Bin index threshold; samples with `bin[feature] <= threshold_bin`
        /// go left.
        threshold_bin: u16,
        /// Whether NaN values should go left (`true`) or right (`false`).
        nan_goes_left: bool,
        /// Index of the left child node in the flat vec.
        left: usize,
        /// Index of the right child node in the flat vec.
        right: usize,
        /// Weighted gain from this split (for feature importance).
        gain: F,
        /// Number of samples that reached this node during training.
        n_samples: usize,
    },
    /// A leaf node that stores a prediction value.
    Leaf {
        /// Predicted value (raw leaf output, e.g. gradient step).
        value: F,
        /// Number of samples that reached this node during training.
        n_samples: usize,
    },
}

// ---------------------------------------------------------------------------
// Internal: binning
// ---------------------------------------------------------------------------

/// The special bin index reserved for NaN / missing values.
const NAN_BIN: u16 = u16::MAX;

/// Bin edges for a single feature, plus the number of non-NaN bins.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct FeatureBinInfo<F> {
    /// Sorted bin edges (upper thresholds). `edges[b]` is the upper bound
    /// for bin `b`. The last bin captures everything above `edges[len-2]`.
    edges: Vec<F>,
    /// Number of non-NaN bins for this feature (at most `max_bins`).
    n_bins: u16,
    /// Whether any NaN was observed for this feature during binning.
    has_nan: bool,
}

/// Compute quantile-based bin edges for every feature.
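///
/// For example, with `max_bins = 4` and unique sorted values
/// `[1.0, 2.0, 3.0, 4.0, 5.0]`, the interior edges fall at the 1/4, 2/4, and
/// 3/4 quantiles of the unique values (`2.0`, `3.0`, `4.0`), and the maximum
/// `5.0` is appended as the final upper bound, giving four bins.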
fn compute_bin_edges<F: Float>(x: &Array2<F>, max_bins: u16) -> Vec<FeatureBinInfo<F>> {
    let n_features = x.ncols();
    let n_samples = x.nrows();
    let mut infos = Vec::with_capacity(n_features);

    for j in 0..n_features {
        let col = x.column(j);
        // Separate non-NaN values and sort them.
        let mut vals: Vec<F> = col.iter().copied().filter(|v| !v.is_nan()).collect();
        vals.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));

        let has_nan = vals.len() < n_samples;

        if vals.is_empty() {
            // All NaN.
            infos.push(FeatureBinInfo {
                edges: Vec::new(),
                n_bins: 0,
                has_nan: true,
            });
            continue;
        }

        // Deduplicate to find unique values.
        let mut unique: Vec<F> = Vec::new();
        for &v in &vals {
            if unique.is_empty() || (v - *unique.last().unwrap()).abs() > F::epsilon() {
                unique.push(v);
            }
        }

        let n_unique = unique.len();
        let actual_bins = (max_bins as usize).min(n_unique);

        if actual_bins <= 1 {
            // Only one unique value — one bin, edge is that value.
            infos.push(FeatureBinInfo {
                edges: vec![unique[0]],
                n_bins: 1,
                has_nan,
            });
            continue;
        }

        // Compute quantile-based bin edges.
        let mut edges = Vec::with_capacity(actual_bins);
        for b in 1..actual_bins {
            let frac = b as f64 / actual_bins as f64;
            let idx_f = frac * (n_unique as f64 - 1.0);
            let lo = idx_f.floor() as usize;
            let hi = (lo + 1).min(n_unique - 1);
            let t = F::from(idx_f - lo as f64).unwrap();
            let edge = unique[lo] * (F::one() - t) + unique[hi] * t;
            // Avoid duplicate edges.
            if edges.is_empty() || (edge - *edges.last().unwrap()).abs() > F::epsilon() {
                edges.push(edge);
            }
        }
        // Add a final edge for the upper bound (max value).
        let last = *unique.last().unwrap();
        if edges.is_empty() || (last - *edges.last().unwrap()).abs() > F::epsilon() {
            edges.push(last);
        }

        let n_bins = edges.len() as u16;
        infos.push(FeatureBinInfo {
            edges,
            n_bins,
            has_nan,
        });
    }

    infos
}

/// Map a single feature value to its bin index given bin edges.
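///
/// For example, with edges `[2.0, 3.0, 4.0, 5.0]`, the value `2.5` maps to
/// bin 1 (the first edge `>= 2.5` is `3.0`, at index 1), `6.0` falls past
/// every edge and is clamped to the last bin (index 3), and NaN maps to
/// `NAN_BIN`.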
#[inline]
fn map_to_bin<F: Float>(value: F, info: &FeatureBinInfo<F>) -> u16 {
    if value.is_nan() {
        return NAN_BIN;
    }
    if info.n_bins == 0 {
        return NAN_BIN;
    }
    // Binary search: find the first edge >= value.
    let edges = &info.edges;
    let mut lo: usize = 0;
    let mut hi: usize = edges.len();
    while lo < hi {
        let mid = lo + (hi - lo) / 2;
        if edges[mid] < value {
            lo = mid + 1;
        } else {
            hi = mid;
        }
    }
    if lo >= edges.len() {
        (info.n_bins - 1).min(edges.len() as u16 - 1)
    } else {
        lo as u16
    }
}

/// Bin all samples for all features.
fn bin_data<F: Float>(x: &Array2<F>, bin_infos: &[FeatureBinInfo<F>]) -> Vec<Vec<u16>> {
    let (n_samples, n_features) = x.dim();
    // bins[i][j] = bin index for sample i, feature j.
    let mut bins = vec![vec![0u16; n_features]; n_samples];
    for j in 0..n_features {
        for i in 0..n_samples {
            bins[i][j] = map_to_bin(x[[i, j]], &bin_infos[j]);
        }
    }
    bins
}

// ---------------------------------------------------------------------------
// Internal: histogram-based tree building
// ---------------------------------------------------------------------------

/// Gradient/hessian accumulator for a single bin.
#[derive(Debug, Clone, Copy, Default)]
struct BinEntry<F> {
    grad_sum: F,
    hess_sum: F,
    count: usize,
}

/// A histogram for one feature: one entry per bin, plus an optional NaN entry.
struct FeatureHistogram<F> {
    bins: Vec<BinEntry<F>>,
    nan_entry: BinEntry<F>,
}

impl<F: Float> FeatureHistogram<F> {
    fn new(n_bins: u16) -> Self {
        Self {
            bins: vec![
                BinEntry {
                    grad_sum: F::zero(),
                    hess_sum: F::zero(),
                    count: 0,
                };
                n_bins as usize
            ],
            nan_entry: BinEntry {
                grad_sum: F::zero(),
                hess_sum: F::zero(),
                count: 0,
            },
        }
    }
}

/// Build histograms for all features from scratch over the given sample indices.
fn build_histograms<F: Float>(
    bin_data: &[Vec<u16>],
    gradients: &[F],
    hessians: &[F],
    sample_indices: &[usize],
    bin_infos: &[FeatureBinInfo<F>],
) -> Vec<FeatureHistogram<F>> {
    let n_features = bin_infos.len();
    let mut histograms: Vec<FeatureHistogram<F>> = bin_infos
        .iter()
        .map(|info| FeatureHistogram::new(info.n_bins))
        .collect();

    for &i in sample_indices {
        let g = gradients[i];
        let h = hessians[i];
        for j in 0..n_features {
            let b = bin_data[i][j];
            if b == NAN_BIN {
                histograms[j].nan_entry.grad_sum = histograms[j].nan_entry.grad_sum + g;
                histograms[j].nan_entry.hess_sum = histograms[j].nan_entry.hess_sum + h;
                histograms[j].nan_entry.count += 1;
            } else {
                let entry = &mut histograms[j].bins[b as usize];
                entry.grad_sum = entry.grad_sum + g;
                entry.hess_sum = entry.hess_sum + h;
                entry.count += 1;
            }
        }
    }

    histograms
}

/// Subtraction trick: compute child histogram from parent and sibling.
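///
/// Histogram sums are additive over disjoint sample sets: for every bin `b`,
/// `parent[b] = left[b] + right[b]`. The larger child is therefore obtained
/// as `parent - smaller_child` in O(n_bins) per feature instead of
/// O(n_samples).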
fn subtract_histograms<F: Float>(
    parent: &[FeatureHistogram<F>],
    sibling: &[FeatureHistogram<F>],
    bin_infos: &[FeatureBinInfo<F>],
) -> Vec<FeatureHistogram<F>> {
    let n_features = bin_infos.len();
    let mut result: Vec<FeatureHistogram<F>> = bin_infos
        .iter()
        .map(|info| FeatureHistogram::new(info.n_bins))
        .collect();

    for j in 0..n_features {
        let n_bins = bin_infos[j].n_bins as usize;
        for b in 0..n_bins {
            result[j].bins[b].grad_sum = parent[j].bins[b].grad_sum - sibling[j].bins[b].grad_sum;
            result[j].bins[b].hess_sum = parent[j].bins[b].hess_sum - sibling[j].bins[b].hess_sum;
            // count is usize so we need saturating_sub for safety
            result[j].bins[b].count = parent[j].bins[b]
                .count
                .saturating_sub(sibling[j].bins[b].count);
        }
        // NaN bin
        result[j].nan_entry.grad_sum = parent[j].nan_entry.grad_sum - sibling[j].nan_entry.grad_sum;
        result[j].nan_entry.hess_sum = parent[j].nan_entry.hess_sum - sibling[j].nan_entry.hess_sum;
        result[j].nan_entry.count = parent[j]
            .nan_entry
            .count
            .saturating_sub(sibling[j].nan_entry.count);
    }

    result
}

/// Result of finding the best split for a node.
struct SplitCandidate<F> {
    feature: usize,
    threshold_bin: u16,
    gain: F,
    nan_goes_left: bool,
}

/// Find the best split across all features from histograms.
///
/// Uses the standard gain formula:
///   gain = (G_L^2 / (H_L + lambda)) + (G_R^2 / (H_R + lambda)) - (G_parent^2 / (H_parent + lambda))
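///
/// For example, with `G_L = -4`, `H_L = 4`, `G_R = 4`, `H_R = 4`, and
/// `lambda = 0`, the parent has `G = 0`, so the gain is `16/4 + 16/4 - 0 = 8`.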
fn find_best_split_from_histograms<F: Float>(
    histograms: &[FeatureHistogram<F>],
    bin_infos: &[FeatureBinInfo<F>],
    total_grad: F,
    total_hess: F,
    total_count: usize,
    l2_regularization: F,
    min_samples_leaf: usize,
) -> Option<SplitCandidate<F>> {
    let n_features = bin_infos.len();
    let parent_gain = total_grad * total_grad / (total_hess + l2_regularization);

    let mut best: Option<SplitCandidate<F>> = None;

    for j in 0..n_features {
        let n_bins = bin_infos[j].n_bins as usize;
        if n_bins <= 1 {
            continue;
        }
        let nan = &histograms[j].nan_entry;

        // Scan left-to-right through bins.
        // For each split position b, left contains bins 0..=b and right
        // contains bins b+1..n_bins.
        // NaN samples can go either left or right; we try both.
        let mut left_grad = F::zero();
        let mut left_hess = F::zero();
        let mut left_count: usize = 0;

        for b in 0..(n_bins - 1) {
            let entry = &histograms[j].bins[b];
            left_grad = left_grad + entry.grad_sum;
            left_hess = left_hess + entry.hess_sum;
            left_count += entry.count;

            let right_grad_no_nan = total_grad - left_grad - nan.grad_sum;
            let right_hess_no_nan = total_hess - left_hess - nan.hess_sum;
            let right_count_no_nan = total_count
                .saturating_sub(left_count)
                .saturating_sub(nan.count);

            // Try NaN goes left.
            {
                let lg = left_grad + nan.grad_sum;
                let lh = left_hess + nan.hess_sum;
                let lc = left_count + nan.count;
                let rg = right_grad_no_nan;
                let rh = right_hess_no_nan;
                let rc = right_count_no_nan;

                if lc >= min_samples_leaf && rc >= min_samples_leaf {
                    let gain = lg * lg / (lh + l2_regularization)
                        + rg * rg / (rh + l2_regularization)
                        - parent_gain;
                    if gain > F::zero() {
                        let better = match &best {
                            None => true,
                            Some(curr) => gain > curr.gain,
                        };
                        if better {
                            best = Some(SplitCandidate {
                                feature: j,
                                threshold_bin: b as u16,
                                gain,
                                nan_goes_left: true,
                            });
                        }
                    }
                }
            }

            // Try NaN goes right.
            {
                let lg = left_grad;
                let lh = left_hess;
                let lc = left_count;
                let rg = right_grad_no_nan + nan.grad_sum;
                let rh = right_hess_no_nan + nan.hess_sum;
                let rc = right_count_no_nan + nan.count;

                if lc >= min_samples_leaf && rc >= min_samples_leaf {
                    let gain = lg * lg / (lh + l2_regularization)
                        + rg * rg / (rh + l2_regularization)
                        - parent_gain;
                    if gain > F::zero() {
                        let better = match &best {
                            None => true,
                            Some(curr) => gain > curr.gain,
                        };
                        if better {
                            best = Some(SplitCandidate {
                                feature: j,
                                threshold_bin: b as u16,
                                gain,
                                nan_goes_left: false,
                            });
                        }
                    }
                }
            }
        }
    }

    best
}

/// Compute leaf value: -G / (H + lambda).
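///
/// This is the Newton step that minimises the second-order loss
/// approximation `G*v + (H + lambda)*v^2 / 2` over the leaf value `v`.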
#[inline]
fn compute_leaf_value<F: Float>(grad_sum: F, hess_sum: F, l2_reg: F) -> F {
    if hess_sum.abs() < F::epsilon() {
        F::zero()
    } else {
        -grad_sum / (hess_sum + l2_reg)
    }
}

/// Parameters for histogram tree building.
struct HistTreeParams<F> {
    max_depth: Option<usize>,
    min_samples_leaf: usize,
    max_leaf_nodes: Option<usize>,
    l2_regularization: F,
}

/// Build a single histogram-based regression tree.
///
/// Returns a vector of `HistNode`s (flat representation).
fn build_hist_tree<F: Float>(
    binned: &[Vec<u16>],
    gradients: &[F],
    hessians: &[F],
    sample_indices: &[usize],
    bin_infos: &[FeatureBinInfo<F>],
    params: &HistTreeParams<F>,
) -> Vec<HistNode<F>> {
    if sample_indices.is_empty() {
        return vec![HistNode::Leaf {
            value: F::zero(),
            n_samples: 0,
        }];
    }

    // Best-first (leaf-wise) growth uses a work queue; depth-first growth uses
    // plain recursion. If max_leaf_nodes is set, use best-first; otherwise
    // depth-first.
    if params.max_leaf_nodes.is_some() {
        build_hist_tree_best_first(
            binned,
            gradients,
            hessians,
            sample_indices,
            bin_infos,
            params,
        )
    } else {
        let mut nodes = Vec::new();
        let hist = build_histograms(binned, gradients, hessians, sample_indices, bin_infos);
        build_hist_tree_recursive(
            binned,
            gradients,
            hessians,
            sample_indices,
            bin_infos,
            params,
            &hist,
            0,
            &mut nodes,
        );
        nodes
    }
}

/// Recursive depth-first histogram tree building.
#[allow(clippy::too_many_arguments)]
fn build_hist_tree_recursive<F: Float>(
    binned: &[Vec<u16>],
    gradients: &[F],
    hessians: &[F],
    sample_indices: &[usize],
    bin_infos: &[FeatureBinInfo<F>],
    params: &HistTreeParams<F>,
    histograms: &[FeatureHistogram<F>],
    depth: usize,
    nodes: &mut Vec<HistNode<F>>,
) -> usize {
    let n = sample_indices.len();
    let grad_sum: F = sample_indices
        .iter()
        .map(|&i| gradients[i])
        .fold(F::zero(), |a, b| a + b);
    let hess_sum: F = sample_indices
        .iter()
        .map(|&i| hessians[i])
        .fold(F::zero(), |a, b| a + b);

    // Check stopping conditions.
    let at_max_depth = params.max_depth.is_some_and(|d| depth >= d);
    let too_few = n < 2 * params.min_samples_leaf;

    if at_max_depth || too_few {
        let idx = nodes.len();
        nodes.push(HistNode::Leaf {
            value: compute_leaf_value(grad_sum, hess_sum, params.l2_regularization),
            n_samples: n,
        });
        return idx;
    }

    // Find best split.
    let split = find_best_split_from_histograms(
        histograms,
        bin_infos,
        grad_sum,
        hess_sum,
        n,
        params.l2_regularization,
        params.min_samples_leaf,
    );

    let split = if let Some(s) = split {
        s
    } else {
        let idx = nodes.len();
        nodes.push(HistNode::Leaf {
            value: compute_leaf_value(grad_sum, hess_sum, params.l2_regularization),
            n_samples: n,
        });
        return idx;
    };

    // Partition samples into left and right.
    let (left_indices, right_indices): (Vec<usize>, Vec<usize>) =
        sample_indices.iter().partition(|&&i| {
            let b = binned[i][split.feature];
            if b == NAN_BIN {
                split.nan_goes_left
            } else {
                b <= split.threshold_bin
            }
        });

    if left_indices.is_empty() || right_indices.is_empty() {
        let idx = nodes.len();
        nodes.push(HistNode::Leaf {
            value: compute_leaf_value(grad_sum, hess_sum, params.l2_regularization),
            n_samples: n,
        });
        return idx;
    }

    // Build histograms for the smaller child, then use subtraction trick.
    let (small_indices, _large_indices, small_is_left) =
        if left_indices.len() <= right_indices.len() {
            (&left_indices, &right_indices, true)
        } else {
            (&right_indices, &left_indices, false)
        };

    let small_hist = build_histograms(binned, gradients, hessians, small_indices, bin_infos);
    let large_hist = subtract_histograms(histograms, &small_hist, bin_infos);

    let (left_hist, right_hist) = if small_is_left {
        (small_hist, large_hist)
    } else {
        (large_hist, small_hist)
    };

    // Reserve a placeholder for this split node.
    let node_idx = nodes.len();
    nodes.push(HistNode::Leaf {
        value: F::zero(),
        n_samples: 0,
    }); // placeholder

    // Recurse.
    let left_idx = build_hist_tree_recursive(
        binned,
        gradients,
        hessians,
        &left_indices,
        bin_infos,
        params,
        &left_hist,
        depth + 1,
        nodes,
    );
    let right_idx = build_hist_tree_recursive(
        binned,
        gradients,
        hessians,
        &right_indices,
        bin_infos,
        params,
        &right_hist,
        depth + 1,
        nodes,
    );

    nodes[node_idx] = HistNode::Split {
        feature: split.feature,
        threshold_bin: split.threshold_bin,
        nan_goes_left: split.nan_goes_left,
        left: left_idx,
        right: right_idx,
        gain: split.gain,
        n_samples: n,
    };

    node_idx
}

/// Entry in the best-first priority queue.
struct SplitTask {
    /// Indices of samples at this node.
    sample_indices: Vec<usize>,
    /// The node index in the flat node vec (a leaf placeholder).
    node_idx: usize,
    /// Depth of this node.
    depth: usize,
    /// The gain of the best split at this node.
    gain: f64,
    /// Feature of the best split.
    feature: usize,
    /// Bin threshold of the best split.
    threshold_bin: u16,
    /// Whether NaN goes left.
    nan_goes_left: bool,
}

/// Build a histogram tree using best-first (leaf-wise) growth with max_leaf_nodes.
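///
/// Repeatedly converts the pending leaf with the highest split gain into a
/// split node until `max_leaf_nodes` leaves exist or no positive-gain split
/// remains, in the spirit of LightGBM's leaf-wise growth.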
fn build_hist_tree_best_first<F: Float>(
    binned: &[Vec<u16>],
    gradients: &[F],
    hessians: &[F],
    sample_indices: &[usize],
    bin_infos: &[FeatureBinInfo<F>],
    params: &HistTreeParams<F>,
) -> Vec<HistNode<F>> {
    let max_leaves = params.max_leaf_nodes.unwrap_or(usize::MAX);
    let mut nodes: Vec<HistNode<F>> = Vec::new();

    // Root node.
    let n = sample_indices.len();
    let grad_sum: F = sample_indices
        .iter()
        .map(|&i| gradients[i])
        .fold(F::zero(), |a, b| a + b);
    let hess_sum: F = sample_indices
        .iter()
        .map(|&i| hessians[i])
        .fold(F::zero(), |a, b| a + b);

    let root_idx = nodes.len();
    nodes.push(HistNode::Leaf {
        value: compute_leaf_value(grad_sum, hess_sum, params.l2_regularization),
        n_samples: n,
    });

    let root_hist = build_histograms(binned, gradients, hessians, sample_indices, bin_infos);
    let root_split = find_best_split_from_histograms(
        &root_hist,
        bin_infos,
        grad_sum,
        hess_sum,
        n,
        params.l2_regularization,
        params.min_samples_leaf,
    );

    let mut pending: Vec<(SplitTask, Vec<FeatureHistogram<F>>)> = Vec::new();
    let mut n_leaves: usize = 1;

    if let Some(split) = root_split {
        let at_max_depth = params.max_depth.is_some_and(|d| d == 0);
        if !at_max_depth {
            pending.push((
                SplitTask {
                    sample_indices: sample_indices.to_vec(),
                    node_idx: root_idx,
                    depth: 0,
                    gain: split.gain.to_f64().unwrap_or(0.0),
                    feature: split.feature,
                    threshold_bin: split.threshold_bin,
                    nan_goes_left: split.nan_goes_left,
                },
                root_hist,
            ));
        }
    }

    while !pending.is_empty() && n_leaves < max_leaves {
        // Pick the task with the highest gain.
        let best_idx = pending
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| {
                a.0.gain
                    .partial_cmp(&b.0.gain)
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
            .map(|(i, _)| i)
            .unwrap();

        let (task, parent_hist) = pending.swap_remove(best_idx);

        // Partition.
        let (left_indices, right_indices): (Vec<usize>, Vec<usize>) =
            task.sample_indices.iter().partition(|&&i| {
                let b = binned[i][task.feature];
                if b == NAN_BIN {
                    task.nan_goes_left
                } else {
                    b <= task.threshold_bin
                }
            });

        if left_indices.is_empty() || right_indices.is_empty() {
            continue;
        }

        // Build histograms using subtraction trick.
        let (small_indices, _large_indices, small_is_left) =
            if left_indices.len() <= right_indices.len() {
                (&left_indices, &right_indices, true)
            } else {
                (&right_indices, &left_indices, false)
            };

        let small_hist = build_histograms(binned, gradients, hessians, small_indices, bin_infos);
        let large_hist = subtract_histograms(&parent_hist, &small_hist, bin_infos);

        let (left_hist, right_hist) = if small_is_left {
            (small_hist, large_hist)
        } else {
            (large_hist, small_hist)
        };

        // Create left and right leaf nodes.
        let left_grad: F = left_indices
            .iter()
            .map(|&i| gradients[i])
            .fold(F::zero(), |a, b| a + b);
        let left_hess: F = left_indices
            .iter()
            .map(|&i| hessians[i])
            .fold(F::zero(), |a, b| a + b);
        let right_grad: F = right_indices
            .iter()
            .map(|&i| gradients[i])
            .fold(F::zero(), |a, b| a + b);
        let right_hess: F = right_indices
            .iter()
            .map(|&i| hessians[i])
            .fold(F::zero(), |a, b| a + b);

        let left_idx = nodes.len();
        nodes.push(HistNode::Leaf {
            value: compute_leaf_value(left_grad, left_hess, params.l2_regularization),
            n_samples: left_indices.len(),
        });
        let right_idx = nodes.len();
        nodes.push(HistNode::Leaf {
            value: compute_leaf_value(right_grad, right_hess, params.l2_regularization),
            n_samples: right_indices.len(),
        });

        // Convert the parent leaf placeholder into a split node.
        nodes[task.node_idx] = HistNode::Split {
            feature: task.feature,
            threshold_bin: task.threshold_bin,
            nan_goes_left: task.nan_goes_left,
            left: left_idx,
            right: right_idx,
            gain: F::from(task.gain).unwrap(),
            n_samples: task.sample_indices.len(),
        };

        // One leaf became two, so net +1 leaf.
        n_leaves += 1;

        let child_depth = task.depth + 1;
        let at_max_depth = params.max_depth.is_some_and(|d| child_depth >= d);

        if !at_max_depth && n_leaves < max_leaves {
            // Try to split left child.
            if left_indices.len() >= 2 * params.min_samples_leaf {
                let left_split = find_best_split_from_histograms(
                    &left_hist,
                    bin_infos,
                    left_grad,
                    left_hess,
                    left_indices.len(),
                    params.l2_regularization,
                    params.min_samples_leaf,
                );
                if let Some(s) = left_split {
                    pending.push((
                        SplitTask {
                            sample_indices: left_indices,
                            node_idx: left_idx,
                            depth: child_depth,
                            gain: s.gain.to_f64().unwrap_or(0.0),
                            feature: s.feature,
                            threshold_bin: s.threshold_bin,
                            nan_goes_left: s.nan_goes_left,
                        },
                        left_hist,
                    ));
                }
            }

            // Try to split right child.
            if right_indices.len() >= 2 * params.min_samples_leaf {
                let right_split = find_best_split_from_histograms(
                    &right_hist,
                    bin_infos,
                    right_grad,
                    right_hess,
                    right_indices.len(),
                    params.l2_regularization,
                    params.min_samples_leaf,
                );
                if let Some(s) = right_split {
                    pending.push((
                        SplitTask {
                            sample_indices: right_indices,
                            node_idx: right_idx,
                            depth: child_depth,
                            gain: s.gain.to_f64().unwrap_or(0.0),
                            feature: s.feature,
                            threshold_bin: s.threshold_bin,
                            nan_goes_left: s.nan_goes_left,
                        },
                        right_hist,
                    ));
                }
            }
        }
    }

    nodes
}

/// Traverse a histogram tree to find the leaf for a single binned sample.
#[inline]
fn traverse_hist_tree<F: Float>(nodes: &[HistNode<F>], sample_bins: &[u16]) -> usize {
    let mut idx = 0;
    loop {
        match &nodes[idx] {
            HistNode::Split {
                feature,
                threshold_bin,
                nan_goes_left,
                left,
                right,
                ..
            } => {
                let b = sample_bins[*feature];
                if b == NAN_BIN {
                    idx = if *nan_goes_left { *left } else { *right };
                } else if b <= *threshold_bin {
                    idx = *left;
                } else {
                    idx = *right;
                }
            }
            HistNode::Leaf { .. } => return idx,
        }
    }
}

/// Compute feature importances from a histogram tree's gain values.
fn compute_hist_feature_importances<F: Float>(
    nodes: &[HistNode<F>],
    n_features: usize,
) -> Array1<F> {
    let mut importances = Array1::zeros(n_features);
    for node in nodes {
        if let HistNode::Split { feature, gain, .. } = node {
            importances[*feature] = importances[*feature] + *gain;
        }
    }
    importances
}

// ---------------------------------------------------------------------------
// Internal helpers
// ---------------------------------------------------------------------------

/// Sigmoid function: 1 / (1 + exp(-x)).
fn sigmoid<F: Float>(x: F) -> F {
    F::one() / (F::one() + (-x).exp())
}

/// Compute softmax probabilities for each class across all samples.
///
/// Returns `probs[k][i]` = probability of class k for sample i.
fn softmax_matrix<F: Float>(
    f_vals: &[Array1<F>],
    n_samples: usize,
    n_classes: usize,
) -> Vec<Vec<F>> {
    let mut probs: Vec<Vec<F>> = vec![vec![F::zero(); n_samples]; n_classes];
    for i in 0..n_samples {
        let max_val = (0..n_classes)
            .map(|k| f_vals[k][i])
            .fold(F::neg_infinity(), |a, b| if b > a { b } else { a });
        let mut sum = F::zero();
        let mut exps = vec![F::zero(); n_classes];
        for k in 0..n_classes {
            exps[k] = (f_vals[k][i] - max_val).exp();
            sum = sum + exps[k];
        }
        let eps = F::from(1e-15).unwrap();
        if sum < eps {
            sum = eps;
        }
        for k in 0..n_classes {
            probs[k][i] = exps[k] / sum;
        }
    }
    probs
}
// ---------------------------------------------------------------------------
// HistGradientBoostingRegressor
// ---------------------------------------------------------------------------

/// Histogram-based gradient boosting regressor.
///
/// Uses quantile-based feature binning and gradient/hessian histograms for
/// O(n_bins) split finding per node. This is significantly faster than the
/// standard [`GradientBoostingRegressor`](crate::GradientBoostingRegressor)
/// for larger datasets.
///
/// # Type Parameters
///
/// - `F`: The floating-point type (`f32` or `f64`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HistGradientBoostingRegressor<F> {
    /// Number of boosting stages (trees).
    pub n_estimators: usize,
    /// Learning rate (shrinkage) applied to each tree's contribution.
    pub learning_rate: f64,
    /// Maximum depth of each tree.
    pub max_depth: Option<usize>,
    /// Minimum number of samples required in a leaf node.
    pub min_samples_leaf: usize,
    /// Maximum number of bins for feature discretisation (at most 256).
    pub max_bins: u16,
    /// L2 regularization term on weights.
    pub l2_regularization: f64,
    /// Maximum number of leaf nodes per tree (best-first growth).
    /// If `None`, depth-first growth is used with `max_depth`.
    pub max_leaf_nodes: Option<usize>,
    /// Loss function.
    pub loss: HistRegressionLoss,
    /// Random seed for reproducibility.
    pub random_state: Option<u64>,
    _marker: std::marker::PhantomData<F>,
}

impl<F: Float> HistGradientBoostingRegressor<F> {
    /// Create a new `HistGradientBoostingRegressor` with default settings.
    ///
    /// Defaults: `n_estimators = 100`, `learning_rate = 0.1`,
    /// `max_depth = None`, `min_samples_leaf = 20`,
    /// `max_bins = 255`, `l2_regularization = 0.0`,
    /// `max_leaf_nodes = Some(31)`, `loss = LeastSquares`.
    #[must_use]
    pub fn new() -> Self {
        Self {
            n_estimators: 100,
            learning_rate: 0.1,
            max_depth: None,
            min_samples_leaf: 20,
            max_bins: 255,
            l2_regularization: 0.0,
            max_leaf_nodes: Some(31),
            loss: HistRegressionLoss::LeastSquares,
            random_state: None,
            _marker: std::marker::PhantomData,
        }
    }

    /// Set the number of boosting stages.
    #[must_use]
    pub fn with_n_estimators(mut self, n: usize) -> Self {
        self.n_estimators = n;
        self
    }

    /// Set the learning rate (shrinkage).
    #[must_use]
    pub fn with_learning_rate(mut self, lr: f64) -> Self {
        self.learning_rate = lr;
        self
    }

    /// Set the maximum tree depth.
    #[must_use]
    pub fn with_max_depth(mut self, d: Option<usize>) -> Self {
        self.max_depth = d;
        self
    }

    /// Set the minimum number of samples in a leaf.
    #[must_use]
    pub fn with_min_samples_leaf(mut self, n: usize) -> Self {
        self.min_samples_leaf = n;
        self
    }

    /// Set the maximum number of bins for feature discretisation.
    #[must_use]
    pub fn with_max_bins(mut self, bins: u16) -> Self {
        self.max_bins = bins;
        self
    }

    /// Set the L2 regularization term.
    #[must_use]
    pub fn with_l2_regularization(mut self, reg: f64) -> Self {
        self.l2_regularization = reg;
        self
    }

    /// Set the maximum number of leaf nodes (best-first growth).
    #[must_use]
    pub fn with_max_leaf_nodes(mut self, n: Option<usize>) -> Self {
        self.max_leaf_nodes = n;
        self
    }

    /// Set the loss function.
    #[must_use]
    pub fn with_loss(mut self, loss: HistRegressionLoss) -> Self {
        self.loss = loss;
        self
    }

    /// Set the random seed for reproducibility.
    #[must_use]
    pub fn with_random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }
}

impl<F: Float> Default for HistGradientBoostingRegressor<F> {
    fn default() -> Self {
        Self::new()
    }
}
// ---------------------------------------------------------------------------
// FittedHistGradientBoostingRegressor
// ---------------------------------------------------------------------------

/// A fitted histogram-based gradient boosting regressor.
///
/// Stores the binning information, initial prediction, and the sequence of
/// fitted histogram trees. Predictions are computed by binning the input
/// features and traversing each tree.
#[derive(Debug, Clone)]
pub struct FittedHistGradientBoostingRegressor<F> {
    /// Bin edge information for each feature.
    bin_infos: Vec<FeatureBinInfo<F>>,
    /// Initial prediction (baseline).
    init: F,
    /// Learning rate used during training.
    learning_rate: F,
    /// Sequence of fitted histogram trees.
    trees: Vec<Vec<HistNode<F>>>,
    /// Number of features.
    n_features: usize,
    /// Per-feature importance scores (normalised).
    feature_importances: Array1<F>,
}

impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<F>>
    for HistGradientBoostingRegressor<F>
{
    type Fitted = FittedHistGradientBoostingRegressor<F>;
    type Error = FerroError;

    /// Fit the histogram-based gradient boosting regressor.
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::ShapeMismatch`] if `x` and `y` have different
    /// numbers of samples.
    /// Returns [`FerroError::InsufficientSamples`] if there are no samples.
    /// Returns [`FerroError::InvalidParameter`] for invalid hyperparameters.
    fn fit(
        &self,
        x: &Array2<F>,
        y: &Array1<F>,
    ) -> Result<FittedHistGradientBoostingRegressor<F>, FerroError> {
        let (n_samples, n_features) = x.dim();

        // Validate inputs.
        if n_samples != y.len() {
            return Err(FerroError::ShapeMismatch {
                expected: vec![n_samples],
                actual: vec![y.len()],
                context: "y length must match number of samples in X".into(),
            });
        }
        if n_samples == 0 {
            return Err(FerroError::InsufficientSamples {
                required: 1,
                actual: 0,
                context: "HistGradientBoostingRegressor requires at least one sample".into(),
            });
        }
        if self.n_estimators == 0 {
            return Err(FerroError::InvalidParameter {
                name: "n_estimators".into(),
                reason: "must be at least 1".into(),
            });
        }
        if self.learning_rate <= 0.0 {
            return Err(FerroError::InvalidParameter {
                name: "learning_rate".into(),
                reason: "must be positive".into(),
            });
        }
        if self.max_bins < 2 {
            return Err(FerroError::InvalidParameter {
                name: "max_bins".into(),
                reason: "must be at least 2".into(),
            });
        }

        let lr = F::from(self.learning_rate).unwrap();
        let l2_reg = F::from(self.l2_regularization).unwrap();

        // Compute bin edges and bin the data.
        let bin_infos = compute_bin_edges(x, self.max_bins);
        let binned = bin_data(x, &bin_infos);

        // Initial prediction.
        let init = match self.loss {
            HistRegressionLoss::LeastSquares => {
                let sum: F = y.iter().copied().fold(F::zero(), |a, b| a + b);
                sum / F::from(n_samples).unwrap()
            }
            HistRegressionLoss::LeastAbsoluteDeviation => median_f(y),
        };

        let mut f_vals = Array1::from_elem(n_samples, init);
        let all_indices: Vec<usize> = (0..n_samples).collect();

        let tree_params = HistTreeParams {
            max_depth: self.max_depth,
            min_samples_leaf: self.min_samples_leaf,
            max_leaf_nodes: self.max_leaf_nodes,
            l2_regularization: l2_reg,
        };

        let mut trees = Vec::with_capacity(self.n_estimators);

        for _ in 0..self.n_estimators {
            // Compute gradients and hessians.
            let (gradients, hessians) = match self.loss {
                HistRegressionLoss::LeastSquares => {
                    let grads: Vec<F> = (0..n_samples).map(|i| -(y[i] - f_vals[i])).collect();
                    let hess: Vec<F> = vec![F::one(); n_samples];
                    (grads, hess)
                }
                HistRegressionLoss::LeastAbsoluteDeviation => {
                    let grads: Vec<F> = (0..n_samples)
                        .map(|i| {
                            let diff = y[i] - f_vals[i];
                            if diff > F::zero() {
                                -F::one()
                            } else if diff < F::zero() {
                                F::one()
                            } else {
                                F::zero()
                            }
                        })
                        .collect();
                    let hess: Vec<F> = vec![F::one(); n_samples];
                    (grads, hess)
                }
            };

            let tree = build_hist_tree(
                &binned,
                &gradients,
                &hessians,
                &all_indices,
                &bin_infos,
                &tree_params,
            );

            // Update predictions.
            for i in 0..n_samples {
                let leaf_idx = traverse_hist_tree(&tree, &binned[i]);
                if let HistNode::Leaf { value, .. } = tree[leaf_idx] {
                    f_vals[i] = f_vals[i] + lr * value;
                }
            }

            trees.push(tree);
        }

        // Compute feature importances.
        let mut total_importances = Array1::<F>::zeros(n_features);
        for tree_nodes in &trees {
            total_importances =
                total_importances + compute_hist_feature_importances(tree_nodes, n_features);
        }
        let imp_sum: F = total_importances
            .iter()
            .copied()
            .fold(F::zero(), |a, b| a + b);
        if imp_sum > F::zero() {
            total_importances.mapv_inplace(|v| v / imp_sum);
        }

        Ok(FittedHistGradientBoostingRegressor {
            bin_infos,
            init,
            learning_rate: lr,
            trees,
            n_features,
            feature_importances: total_importances,
        })
    }
}

impl<F: Float + Send + Sync + 'static> Predict<Array2<F>>
    for FittedHistGradientBoostingRegressor<F>
{
    type Output = Array1<F>;
    type Error = FerroError;

    /// Predict target values.
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::ShapeMismatch`] if the number of features does
    /// not match the fitted model.
    fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
        if x.ncols() != self.n_features {
            return Err(FerroError::ShapeMismatch {
                expected: vec![self.n_features],
                actual: vec![x.ncols()],
                context: "number of features must match fitted model".into(),
            });
        }

        let n_samples = x.nrows();
        let binned = bin_data(x, &self.bin_infos);
        let mut predictions = Array1::from_elem(n_samples, self.init);

        for i in 0..n_samples {
            for tree_nodes in &self.trees {
                let leaf_idx = traverse_hist_tree(tree_nodes, &binned[i]);
                if let HistNode::Leaf { value, .. } = tree_nodes[leaf_idx] {
                    predictions[i] = predictions[i] + self.learning_rate * value;
                }
            }
        }

        Ok(predictions)
    }
}

impl<F: Float + Send + Sync + 'static> HasFeatureImportances<F>
    for FittedHistGradientBoostingRegressor<F>
{
    fn feature_importances(&self) -> &Array1<F> {
        &self.feature_importances
    }
}

// Pipeline integration.
impl<F: Float + Send + Sync + 'static> PipelineEstimator<F> for HistGradientBoostingRegressor<F> {
    fn fit_pipeline(
        &self,
        x: &Array2<F>,
        y: &Array1<F>,
    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
        let fitted = self.fit(x, y)?;
        Ok(Box::new(fitted))
    }
}

impl<F: Float + Send + Sync + 'static> FittedPipelineEstimator<F>
    for FittedHistGradientBoostingRegressor<F>
{
    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
        self.predict(x)
    }
}

impl<F: Float + Send + Sync + 'static> FittedHistGradientBoostingRegressor<F> {
    /// R² coefficient of determination on the given test data.
    /// Equivalent to sklearn's `RegressorMixin.score`.
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::ShapeMismatch`] if `x.nrows() != y.len()` or
    /// the feature count does not match the training data.
    pub fn score(&self, x: &Array2<F>, y: &Array1<F>) -> Result<F, FerroError> {
        if x.nrows() != y.len() {
            return Err(FerroError::ShapeMismatch {
                expected: vec![x.nrows()],
                actual: vec![y.len()],
                context: "y length must match number of samples in X".into(),
            });
        }
        let preds = self.predict(x)?;
        Ok(crate::r2_score(&preds, y))
    }
}

// ---------------------------------------------------------------------------
// HistGradientBoostingClassifier
// ---------------------------------------------------------------------------
/// Histogram-based gradient boosting classifier.
///
/// For binary classification a single model is trained on log-odds residuals.
/// For multiclass (*K* classes), *K* histogram trees are built per boosting
/// round (one tree per class, with gradients coupled through the softmax).
1402///
1403/// # Type Parameters
1404///
1405/// - `F`: The floating-point type (`f32` or `f64`).
1406#[derive(Debug, Clone, Serialize, Deserialize)]
1407pub struct HistGradientBoostingClassifier<F> {
1408    /// Number of boosting stages.
1409    pub n_estimators: usize,
1410    /// Learning rate (shrinkage).
1411    pub learning_rate: f64,
1412    /// Maximum depth of each tree.
1413    pub max_depth: Option<usize>,
1414    /// Minimum number of samples required in a leaf node.
1415    pub min_samples_leaf: usize,
1416    /// Maximum number of bins for feature discretisation (at most 256).
1417    pub max_bins: u16,
1418    /// L2 regularization term on weights.
1419    pub l2_regularization: f64,
1420    /// Maximum number of leaf nodes per tree (best-first growth).
1421    pub max_leaf_nodes: Option<usize>,
1422    /// Classification loss function.
1423    pub loss: HistClassificationLoss,
1424    /// Random seed for reproducibility (reserved for future subsampling).
1425    pub random_state: Option<u64>,
1426    _marker: std::marker::PhantomData<F>,
1427}
1428
1429impl<F: Float> HistGradientBoostingClassifier<F> {
1430    /// Create a new `HistGradientBoostingClassifier` with default settings.
1431    ///
1432    /// Defaults: `n_estimators = 100`, `learning_rate = 0.1`,
1433    /// `max_depth = None`, `min_samples_leaf = 20`,
1434    /// `max_bins = 255`, `l2_regularization = 0.0`,
1435    /// `max_leaf_nodes = Some(31)`, `loss = LogLoss`.
1436    #[must_use]
1437    pub fn new() -> Self {
1438        Self {
1439            n_estimators: 100,
1440            learning_rate: 0.1,
1441            max_depth: None,
1442            min_samples_leaf: 20,
1443            max_bins: 255,
1444            l2_regularization: 0.0,
1445            max_leaf_nodes: Some(31),
1446            loss: HistClassificationLoss::LogLoss,
1447            random_state: None,
1448            _marker: std::marker::PhantomData,
1449        }
1450    }
1451
1452    /// Set the number of boosting stages.
1453    #[must_use]
1454    pub fn with_n_estimators(mut self, n: usize) -> Self {
1455        self.n_estimators = n;
1456        self
1457    }
1458
1459    /// Set the learning rate (shrinkage).
1460    #[must_use]
1461    pub fn with_learning_rate(mut self, lr: f64) -> Self {
1462        self.learning_rate = lr;
1463        self
1464    }
1465
1466    /// Set the maximum tree depth.
1467    #[must_use]
1468    pub fn with_max_depth(mut self, d: Option<usize>) -> Self {
1469        self.max_depth = d;
1470        self
1471    }
1472
1473    /// Set the minimum number of samples in a leaf.
1474    #[must_use]
1475    pub fn with_min_samples_leaf(mut self, n: usize) -> Self {
1476        self.min_samples_leaf = n;
1477        self
1478    }
1479
1480    /// Set the maximum number of bins for feature discretisation.
1481    #[must_use]
1482    pub fn with_max_bins(mut self, bins: u16) -> Self {
1483        self.max_bins = bins;
1484        self
1485    }
1486
1487    /// Set the L2 regularization term.
1488    #[must_use]
1489    pub fn with_l2_regularization(mut self, reg: f64) -> Self {
1490        self.l2_regularization = reg;
1491        self
1492    }
1493
1494    /// Set the maximum number of leaf nodes (best-first growth).
1495    #[must_use]
1496    pub fn with_max_leaf_nodes(mut self, n: Option<usize>) -> Self {
1497        self.max_leaf_nodes = n;
1498        self
1499    }
1500
1501    /// Set the random seed for reproducibility.
1502    #[must_use]
1503    pub fn with_random_state(mut self, seed: u64) -> Self {
1504        self.random_state = Some(seed);
1505        self
1506    }
1507}
1508
1509impl<F: Float> Default for HistGradientBoostingClassifier<F> {
1510    fn default() -> Self {
1511        Self::new()
1512    }
1513}
1514
1515// ---------------------------------------------------------------------------
1516// FittedHistGradientBoostingClassifier
1517// ---------------------------------------------------------------------------
1518
1519/// A fitted histogram-based gradient boosting classifier.
1520///
1521/// For binary classification, stores a single sequence of trees predicting log-odds.
1522/// For multiclass, stores `K` sequences of trees (one per class).
1523#[derive(Debug, Clone)]
1524pub struct FittedHistGradientBoostingClassifier<F> {
1525    /// Bin edge information for each feature.
1526    bin_infos: Vec<FeatureBinInfo<F>>,
1527    /// Sorted unique class labels.
1528    classes: Vec<usize>,
1529    /// Initial predictions per class (log-odds or log-prior).
1530    init: Vec<F>,
1531    /// Learning rate.
1532    learning_rate: F,
1533    /// Trees: for binary, `trees[0]` has all trees. For multiclass,
1534    /// `trees[k]` has trees for class k.
1535    trees: Vec<Vec<Vec<HistNode<F>>>>,
1536    /// Number of features.
1537    n_features: usize,
1538    /// Per-feature importance scores (normalised).
1539    feature_importances: Array1<F>,
1540}
1541
1542impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<usize>>
1543    for HistGradientBoostingClassifier<F>
1544{
1545    type Fitted = FittedHistGradientBoostingClassifier<F>;
1546    type Error = FerroError;
1547
1548    /// Fit the histogram-based gradient boosting classifier.
1549    ///
1550    /// # Errors
1551    ///
1552    /// Returns [`FerroError::ShapeMismatch`] if `x` and `y` have different
1553    /// numbers of samples.
1554    /// Returns [`FerroError::InsufficientSamples`] if there are no samples.
1555    /// Returns [`FerroError::InvalidParameter`] for invalid hyperparameters.
1556    fn fit(
1557        &self,
1558        x: &Array2<F>,
1559        y: &Array1<usize>,
1560    ) -> Result<FittedHistGradientBoostingClassifier<F>, FerroError> {
1561        let (n_samples, n_features) = x.dim();
1562
1563        if n_samples != y.len() {
1564            return Err(FerroError::ShapeMismatch {
1565                expected: vec![n_samples],
1566                actual: vec![y.len()],
1567                context: "y length must match number of samples in X".into(),
1568            });
1569        }
1570        if n_samples == 0 {
1571            return Err(FerroError::InsufficientSamples {
1572                required: 1,
1573                actual: 0,
1574                context: "HistGradientBoostingClassifier requires at least one sample".into(),
1575            });
1576        }
1577        if self.n_estimators == 0 {
1578            return Err(FerroError::InvalidParameter {
1579                name: "n_estimators".into(),
1580                reason: "must be at least 1".into(),
1581            });
1582        }
1583        if self.learning_rate <= 0.0 {
1584            return Err(FerroError::InvalidParameter {
1585                name: "learning_rate".into(),
1586                reason: "must be positive".into(),
1587            });
1588        }
1589        if self.max_bins < 2 {
1590            return Err(FerroError::InvalidParameter {
1591                name: "max_bins".into(),
1592                reason: "must be at least 2".into(),
1593            });
1594        }
1595
1596        // Determine unique classes.
1597        let mut classes: Vec<usize> = y.iter().copied().collect();
1598        classes.sort_unstable();
1599        classes.dedup();
1600        let n_classes = classes.len();
1601
1602        if n_classes < 2 {
1603            return Err(FerroError::InvalidParameter {
1604                name: "y".into(),
1605                reason: "need at least 2 distinct classes".into(),
1606            });
1607        }
1608
1609        let y_mapped: Vec<usize> = y
1610            .iter()
1611            .map(|&c| classes.binary_search(&c).unwrap()) // classes is sorted and deduped
1612            .collect();
1613
1614        let lr = F::from(self.learning_rate).unwrap();
1615        let l2_reg = F::from(self.l2_regularization).unwrap();
1616
1617        // Bin the data.
1618        let bin_infos = compute_bin_edges(x, self.max_bins);
1619        let binned = bin_data(x, &bin_infos);
1620
1621        let tree_params = HistTreeParams {
1622            max_depth: self.max_depth,
1623            min_samples_leaf: self.min_samples_leaf,
1624            max_leaf_nodes: self.max_leaf_nodes,
1625            l2_regularization: l2_reg,
1626        };
1627
1628        let all_indices: Vec<usize> = (0..n_samples).collect();
1629
1630        if n_classes == 2 {
1631            self.fit_binary(
1632                &binned,
1633                &y_mapped,
1634                n_samples,
1635                n_features,
1636                &classes,
1637                lr,
1638                &tree_params,
1639                &all_indices,
1640                &bin_infos,
1641            )
1642        } else {
1643            self.fit_multiclass(
1644                &binned,
1645                &y_mapped,
1646                n_samples,
1647                n_features,
1648                n_classes,
1649                &classes,
1650                lr,
1651                &tree_params,
1652                &all_indices,
1653                &bin_infos,
1654            )
1655        }
1656    }
1657}
1658
1659impl<F: Float + Send + Sync + 'static> HistGradientBoostingClassifier<F> {
1660    /// Fit binary classification (log-loss on log-odds).
1661    #[allow(clippy::too_many_arguments)]
1662    fn fit_binary(
1663        &self,
1664        binned: &[Vec<u16>],
1665        y_mapped: &[usize],
1666        n_samples: usize,
1667        n_features: usize,
1668        classes: &[usize],
1669        lr: F,
1670        tree_params: &HistTreeParams<F>,
1671        all_indices: &[usize],
1672        bin_infos: &[FeatureBinInfo<F>],
1673    ) -> Result<FittedHistGradientBoostingClassifier<F>, FerroError> {
1674        // Initial log-odds.
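        // For base rate p, the constant minimiser of the log-loss is the
        // log-odds F_0 = ln(p / (1 - p)); clipping p away from {0, 1} keeps
        // the logit finite.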
1675        let pos_count = y_mapped.iter().filter(|&&c| c == 1).count();
1676        let p = F::from(pos_count).unwrap() / F::from(n_samples).unwrap();
1677        let eps = F::from(1e-15).unwrap();
1678        let p_clipped = p.max(eps).min(F::one() - eps);
1679        let init_val = (p_clipped / (F::one() - p_clipped)).ln();
1680
1681        let mut f_vals = Array1::from_elem(n_samples, init_val);
1682        let mut trees_seq: Vec<Vec<HistNode<F>>> = Vec::with_capacity(self.n_estimators);
1683
1684        for _ in 0..self.n_estimators {
1685            // Compute probabilities.
1686            let probs: Vec<F> = f_vals.iter().map(|&fv| sigmoid(fv)).collect();
1687
1688            // Gradients and hessians for log-loss:
1689            //   gradient = p - y, i.e. dL/dF of -[y ln p + (1 - y) ln(1 - p)]; the tree fits the negative gradient
1690            //   hessian = p * (1 - p)
1691            let gradients: Vec<F> = (0..n_samples)
1692                .map(|i| {
1693                    let yi = F::from(y_mapped[i]).unwrap();
1694                    probs[i] - yi
1695                })
1696                .collect();
1697            let hessians: Vec<F> = (0..n_samples)
1698                .map(|i| {
1699                    let pi = probs[i].max(eps).min(F::one() - eps);
1700                    pi * (F::one() - pi)
1701                })
1702                .collect();
1703
1704            let tree = build_hist_tree(
1705                binned,
1706                &gradients,
1707                &hessians,
1708                all_indices,
1709                bin_infos,
1710                tree_params,
1711            );
1712
1713            // Update f_vals.
1714            for i in 0..n_samples {
1715                let leaf_idx = traverse_hist_tree(&tree, &binned[i]);
1716                if let HistNode::Leaf { value, .. } = tree[leaf_idx] {
1717                    f_vals[i] = f_vals[i] + lr * value;
1718                }
1719            }
1720
1721            trees_seq.push(tree);
1722        }
1723
1724        // Feature importances.
1725        let mut total_importances = Array1::<F>::zeros(n_features);
1726        for tree_nodes in &trees_seq {
1727            total_importances =
1728                total_importances + compute_hist_feature_importances(tree_nodes, n_features);
1729        }
1730        let imp_sum: F = total_importances
1731            .iter()
1732            .copied()
1733            .fold(F::zero(), |a, b| a + b);
1734        if imp_sum > F::zero() {
1735            total_importances.mapv_inplace(|v| v / imp_sum);
1736        }
1737
1738        Ok(FittedHistGradientBoostingClassifier {
1739            bin_infos: bin_infos.to_vec(),
1740            classes: classes.to_vec(),
1741            init: vec![init_val],
1742            learning_rate: lr,
1743            trees: vec![trees_seq],
1744            n_features,
1745            feature_importances: total_importances,
1746        })
1747    }
1748
1749    /// Fit multiclass classification (K trees per round, softmax).
1750    #[allow(clippy::too_many_arguments)]
1751    fn fit_multiclass(
1752        &self,
1753        binned: &[Vec<u16>],
1754        y_mapped: &[usize],
1755        n_samples: usize,
1756        n_features: usize,
1757        n_classes: usize,
1758        classes: &[usize],
1759        lr: F,
1760        tree_params: &HistTreeParams<F>,
1761        all_indices: &[usize],
1762        bin_infos: &[FeatureBinInfo<F>],
1763    ) -> Result<FittedHistGradientBoostingClassifier<F>, FerroError> {
1764        // Initial log-prior for each class.
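        // Softmax of (ln p_1, ..., ln p_K) recovers the empirical class
        // frequencies, so every sample starts at the prior distribution.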
1765        let mut class_counts = vec![0usize; n_classes];
1766        for &c in y_mapped {
1767            class_counts[c] += 1;
1768        }
1769        let n_f = F::from(n_samples).unwrap();
1770        let eps = F::from(1e-15).unwrap();
1771        let init_vals: Vec<F> = class_counts
1772            .iter()
1773            .map(|&cnt| {
1774                let p = (F::from(cnt).unwrap() / n_f).max(eps);
1775                p.ln()
1776            })
1777            .collect();
1778
1779        let mut f_vals: Vec<Array1<F>> = init_vals
1780            .iter()
1781            .map(|&init| Array1::from_elem(n_samples, init))
1782            .collect();
1783
1784        let mut trees_per_class: Vec<Vec<Vec<HistNode<F>>>> = (0..n_classes)
1785            .map(|_| Vec::with_capacity(self.n_estimators))
1786            .collect();
1787
1788        for _ in 0..self.n_estimators {
1789            let probs = softmax_matrix(&f_vals, n_samples, n_classes);
1790
1791            for k in 0..n_classes {
1792                // Gradients and hessians for softmax cross-entropy:
1793                //   gradient_k = p_k - y_k
1794                //   hessian_k  = p_k * (1 - p_k) (the diagonal of the softmax Hessian; cross terms are ignored)
1795                let gradients: Vec<F> = (0..n_samples)
1796                    .map(|i| {
1797                        let yi_k = if y_mapped[i] == k {
1798                            F::one()
1799                        } else {
1800                            F::zero()
1801                        };
1802                        probs[k][i] - yi_k
1803                    })
1804                    .collect();
1805                let hessians: Vec<F> = (0..n_samples)
1806                    .map(|i| {
1807                        let pk = probs[k][i].max(eps).min(F::one() - eps);
1808                        pk * (F::one() - pk)
1809                    })
1810                    .collect();
1811
1812                let tree = build_hist_tree(
1813                    binned,
1814                    &gradients,
1815                    &hessians,
1816                    all_indices,
1817                    bin_infos,
1818                    tree_params,
1819                );
1820
1821                // Update f_vals for class k.
1822                for (i, fv) in f_vals[k].iter_mut().enumerate() {
1823                    let leaf_idx = traverse_hist_tree(&tree, &binned[i]);
1824                    if let HistNode::Leaf { value, .. } = tree[leaf_idx] {
1825                        *fv = *fv + lr * value;
1826                    }
1827                }
1828
1829                trees_per_class[k].push(tree);
1830            }
1831        }
1832
1833        // Feature importances aggregated across all classes and rounds.
1834        let mut total_importances = Array1::<F>::zeros(n_features);
1835        for class_trees in &trees_per_class {
1836            for tree_nodes in class_trees {
1837                total_importances =
1838                    total_importances + compute_hist_feature_importances(tree_nodes, n_features);
1839            }
1840        }
1841        let imp_sum: F = total_importances
1842            .iter()
1843            .copied()
1844            .fold(F::zero(), |a, b| a + b);
1845        if imp_sum > F::zero() {
1846            total_importances.mapv_inplace(|v| v / imp_sum);
1847        }
1848
1849        Ok(FittedHistGradientBoostingClassifier {
1850            bin_infos: bin_infos.to_vec(),
1851            classes: classes.to_vec(),
1852            init: init_vals,
1853            learning_rate: lr,
1854            trees: trees_per_class,
1855            n_features,
1856            feature_importances: total_importances,
1857        })
1858    }
1859}
1860
1861impl<F: Float + Send + Sync + 'static> Predict<Array2<F>>
1862    for FittedHistGradientBoostingClassifier<F>
1863{
1864    type Output = Array1<usize>;
1865    type Error = FerroError;
1866
1867    /// Predict class labels.
1868    ///
1869    /// # Errors
1870    ///
1871    /// Returns [`FerroError::ShapeMismatch`] if the number of features does
1872    /// not match the fitted model.
1873    fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
1874        if x.ncols() != self.n_features {
1875            return Err(FerroError::ShapeMismatch {
1876                expected: vec![self.n_features],
1877                actual: vec![x.ncols()],
1878                context: "number of features must match fitted model".into(),
1879            });
1880        }
1881
1882        let n_samples = x.nrows();
1883        let n_classes = self.classes.len();
1884        let binned = bin_data(x, &self.bin_infos);
1885
1886        if n_classes == 2 {
1887            let init = self.init[0];
1888            let mut predictions = Array1::zeros(n_samples);
1889            for i in 0..n_samples {
1890                let mut f_val = init;
1891                for tree_nodes in &self.trees[0] {
1892                    let leaf_idx = traverse_hist_tree(tree_nodes, &binned[i]);
1893                    if let HistNode::Leaf { value, .. } = tree_nodes[leaf_idx] {
1894                        f_val = f_val + self.learning_rate * value;
1895                    }
1896                }
1897                let prob = sigmoid(f_val);
1898                let class_idx = if prob >= F::from(0.5).unwrap() { 1 } else { 0 };
1899                predictions[i] = self.classes[class_idx];
1900            }
1901            Ok(predictions)
1902        } else {
1903            let mut predictions = Array1::zeros(n_samples);
1904            for i in 0..n_samples {
1905                let mut scores = Vec::with_capacity(n_classes);
1906                for k in 0..n_classes {
1907                    let mut f_val = self.init[k];
1908                    for tree_nodes in &self.trees[k] {
1909                        let leaf_idx = traverse_hist_tree(tree_nodes, &binned[i]);
1910                        if let HistNode::Leaf { value, .. } = tree_nodes[leaf_idx] {
1911                            f_val = f_val + self.learning_rate * value;
1912                        }
1913                    }
1914                    scores.push(f_val);
1915                }
1916                let best_k = scores
1917                    .iter()
1918                    .enumerate()
1919                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
1920                    .map_or(0, |(k, _)| k);
1921                predictions[i] = self.classes[best_k];
1922            }
1923            Ok(predictions)
1924        }
1925    }
1926}
1927
1928impl<F: Float + Send + Sync + 'static> HasFeatureImportances<F>
1929    for FittedHistGradientBoostingClassifier<F>
1930{
1931    fn feature_importances(&self) -> &Array1<F> {
1932        &self.feature_importances
1933    }
1934}
1935
1936impl<F: Float + Send + Sync + 'static> HasClasses for FittedHistGradientBoostingClassifier<F> {
1937    fn classes(&self) -> &[usize] {
1938        &self.classes
1939    }
1940
1941    fn n_classes(&self) -> usize {
1942        self.classes.len()
1943    }
1944}
1945
1946impl<F: Float + Send + Sync + 'static> FittedHistGradientBoostingClassifier<F> {
1947    /// Mean accuracy on the given test data and labels.
1948    /// Equivalent to sklearn's `ClassifierMixin.score`.
1949    ///
1950    /// # Errors
1951    ///
1952    /// Returns [`FerroError::ShapeMismatch`] if `x.nrows() != y.len()` or
1953    /// the feature count does not match the training data.
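    ///
    /// # Examples
    ///
    /// A minimal sketch (training setup hidden; parameters illustrative):
    ///
    /// ```
    /// # use ferrolearn_tree::HistGradientBoostingClassifier;
    /// # use ferrolearn_core::Fit;
    /// # use ndarray::{array, Array2};
    /// # let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 7.0, 8.0, 9.0]).unwrap();
    /// # let y = array![0usize, 0, 0, 1, 1, 1];
    /// # let fitted = HistGradientBoostingClassifier::<f64>::new()
    /// #     .with_n_estimators(20)
    /// #     .with_min_samples_leaf(1)
    /// #     .with_random_state(0)
    /// #     .fit(&x, &y)
    /// #     .unwrap();
    /// let acc = fitted.score(&x, &y).unwrap();
    /// assert!((0.0..=1.0).contains(&acc));
    /// ```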
1954    pub fn score(&self, x: &Array2<F>, y: &Array1<usize>) -> Result<F, FerroError> {
1955        if x.nrows() != y.len() {
1956            return Err(FerroError::ShapeMismatch {
1957                expected: vec![x.nrows()],
1958                actual: vec![y.len()],
1959                context: "y length must match number of samples in X".into(),
1960            });
1961        }
1962        let preds = self.predict(x)?;
1963        Ok(crate::mean_accuracy(&preds, y))
1964    }
1965
1966    /// Predict class probabilities. Mirrors sklearn's
1967    /// `HistGradientBoostingClassifier.predict_proba`.
1968    ///
1969    /// Binary: the logistic sigmoid of the cumulative log-odds.
1970    /// Multiclass: softmax over K cumulative scores.
1971    ///
1972    /// Returns shape `(n_samples, n_classes)`; rows sum to 1.
1973    ///
1974    /// # Errors
1975    ///
1976    /// Returns [`FerroError::ShapeMismatch`] if the number of features
1977    /// does not match the fitted model.
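    ///
    /// # Examples
    ///
    /// Each row is a probability distribution over the classes (training
    /// setup hidden; parameters illustrative):
    ///
    /// ```
    /// # use ferrolearn_tree::HistGradientBoostingClassifier;
    /// # use ferrolearn_core::Fit;
    /// # use ndarray::{array, Array2, Axis};
    /// # let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 7.0, 8.0, 9.0]).unwrap();
    /// # let y = array![0usize, 0, 0, 1, 1, 1];
    /// # let fitted = HistGradientBoostingClassifier::<f64>::new()
    /// #     .with_n_estimators(20)
    /// #     .with_min_samples_leaf(1)
    /// #     .with_random_state(0)
    /// #     .fit(&x, &y)
    /// #     .unwrap();
    /// let proba = fitted.predict_proba(&x).unwrap();
    /// assert_eq!(proba.dim(), (6, 2));
    /// for s in proba.sum_axis(Axis(1)).iter() {
    ///     assert!((s - 1.0).abs() < 1e-9);
    /// }
    /// ```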
1978    pub fn predict_proba(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
1979        if x.ncols() != self.n_features {
1980            return Err(FerroError::ShapeMismatch {
1981                expected: vec![self.n_features],
1982                actual: vec![x.ncols()],
1983                context: "number of features must match fitted model".into(),
1984            });
1985        }
1986        let n_samples = x.nrows();
1987        let n_classes = self.classes.len();
1988        let binned = bin_data(x, &self.bin_infos);
1989        let mut proba = Array2::<F>::zeros((n_samples, n_classes));
1990
1991        if n_classes == 2 {
1992            let init = self.init[0];
1993            for i in 0..n_samples {
1994                let mut f_val = init;
1995                for tree_nodes in &self.trees[0] {
1996                    let leaf_idx = traverse_hist_tree(tree_nodes, &binned[i]);
1997                    if let HistNode::Leaf { value, .. } = tree_nodes[leaf_idx] {
1998                        f_val = f_val + self.learning_rate * value;
1999                    }
2000                }
2001                let p1 = sigmoid(f_val);
2002                proba[[i, 0]] = F::one() - p1;
2003                proba[[i, 1]] = p1;
2004            }
2005        } else {
2006            for i in 0..n_samples {
2007                let mut scores = vec![F::zero(); n_classes];
2008                for k in 0..n_classes {
2009                    let mut f_val = self.init[k];
2010                    for tree_nodes in &self.trees[k] {
2011                        let leaf_idx = traverse_hist_tree(tree_nodes, &binned[i]);
2012                        if let HistNode::Leaf { value, .. } = tree_nodes[leaf_idx] {
2013                            f_val = f_val + self.learning_rate * value;
2014                        }
2015                    }
2016                    scores[k] = f_val;
2017                }
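                // Log-sum-exp trick: subtracting the row maximum before
                // exponentiating keeps exp() from overflowing for large scores.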
2018                let max_s = scores
2019                    .iter()
2020                    .copied()
2021                    .fold(F::neg_infinity(), |a, b| if b > a { b } else { a });
2022                let mut sum_exp = F::zero();
2023                for k in 0..n_classes {
2024                    let e = (scores[k] - max_s).exp();
2025                    proba[[i, k]] = e;
2026                    sum_exp = sum_exp + e;
2027                }
2028                if sum_exp > F::zero() {
2029                    for k in 0..n_classes {
2030                        proba[[i, k]] = proba[[i, k]] / sum_exp;
2031                    }
2032                }
2033            }
2034        }
2035        Ok(proba)
2036    }
2037
2038    /// Element-wise log of [`predict_proba`](Self::predict_proba). Mirrors
2039    /// sklearn's `ClassifierMixin.predict_log_proba`.
2040    ///
2041    /// # Errors
2042    ///
2043    /// Forwards any error from [`predict_proba`](Self::predict_proba).
2044    pub fn predict_log_proba(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
2045        let proba = self.predict_proba(x)?;
2046        Ok(crate::log_proba(&proba))
2047    }
2048
2049    /// Cumulative raw scores per sample (pre-link). Mirrors sklearn's
2050    /// `HistGradientBoostingClassifier.decision_function`.
2051    ///
2052    /// Binary: shape `(n_samples, 1)`. Multiclass: shape
2053    /// `(n_samples, n_classes)`.
2054    ///
2055    /// # Errors
2056    ///
2057    /// Returns [`FerroError::ShapeMismatch`] if the number of features
2058    /// does not match the fitted model.
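    ///
    /// # Examples
    ///
    /// Binary models return a single column of cumulative log-odds
    /// (training setup hidden; parameters illustrative):
    ///
    /// ```
    /// # use ferrolearn_tree::HistGradientBoostingClassifier;
    /// # use ferrolearn_core::Fit;
    /// # use ndarray::{array, Array2};
    /// # let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 7.0, 8.0, 9.0]).unwrap();
    /// # let y = array![0usize, 0, 0, 1, 1, 1];
    /// # let fitted = HistGradientBoostingClassifier::<f64>::new()
    /// #     .with_n_estimators(20)
    /// #     .with_min_samples_leaf(1)
    /// #     .with_random_state(0)
    /// #     .fit(&x, &y)
    /// #     .unwrap();
    /// let scores = fitted.decision_function(&x).unwrap();
    /// assert_eq!(scores.dim(), (6, 1));
    /// ```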
2059    pub fn decision_function(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
2060        if x.ncols() != self.n_features {
2061            return Err(FerroError::ShapeMismatch {
2062                expected: vec![self.n_features],
2063                actual: vec![x.ncols()],
2064                context: "number of features must match fitted model".into(),
2065            });
2066        }
2067        let n_samples = x.nrows();
2068        let n_classes = self.classes.len();
2069        let binned = bin_data(x, &self.bin_infos);
2070
2071        if n_classes == 2 {
2072            let init = self.init[0];
2073            let mut out = Array2::<F>::zeros((n_samples, 1));
2074            for i in 0..n_samples {
2075                let mut f_val = init;
2076                for tree_nodes in &self.trees[0] {
2077                    let leaf_idx = traverse_hist_tree(tree_nodes, &binned[i]);
2078                    if let HistNode::Leaf { value, .. } = tree_nodes[leaf_idx] {
2079                        f_val = f_val + self.learning_rate * value;
2080                    }
2081                }
2082                out[[i, 0]] = f_val;
2083            }
2084            Ok(out)
2085        } else {
2086            let mut out = Array2::<F>::zeros((n_samples, n_classes));
2087            for i in 0..n_samples {
2088                for k in 0..n_classes {
2089                    let mut f_val = self.init[k];
2090                    for tree_nodes in &self.trees[k] {
2091                        let leaf_idx = traverse_hist_tree(tree_nodes, &binned[i]);
2092                        if let HistNode::Leaf { value, .. } = tree_nodes[leaf_idx] {
2093                            f_val = f_val + self.learning_rate * value;
2094                        }
2095                    }
2096                    out[[i, k]] = f_val;
2097                }
2098            }
2099            Ok(out)
2100        }
2101    }
2102}
2103
2104// Pipeline integration.
2105impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> PipelineEstimator<F>
2106    for HistGradientBoostingClassifier<F>
2107{
2108    fn fit_pipeline(
2109        &self,
2110        x: &Array2<F>,
2111        y: &Array1<F>,
2112    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
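        // Pipeline targets arrive as floats; `to_usize` truncates fractional
        // values, and unrepresentable ones (negative, NaN) fall back to class 0.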
2113        let y_usize: Array1<usize> = y.mapv(|v| v.to_usize().unwrap_or(0));
2114        let fitted = self.fit(x, &y_usize)?;
2115        Ok(Box::new(FittedHgbcPipelineAdapter(fitted)))
2116    }
2117}
2118
2119/// Pipeline adapter for `FittedHistGradientBoostingClassifier<F>`.
2120struct FittedHgbcPipelineAdapter<F: Float + Send + Sync + 'static>(
2121    FittedHistGradientBoostingClassifier<F>,
2122);
2123
2124impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> FittedPipelineEstimator<F>
2125    for FittedHgbcPipelineAdapter<F>
2126{
2127    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
2128        let preds = self.0.predict(x)?;
2129        Ok(preds.mapv(|v| F::from_usize(v).unwrap_or_else(F::nan)))
2130    }
2131}
2132
2133// ---------------------------------------------------------------------------
2134// Internal helpers
2135// ---------------------------------------------------------------------------
2136
2137/// Compute the median of an `Array1`.
2138fn median_f<F: Float>(arr: &Array1<F>) -> F {
2139    let mut sorted: Vec<F> = arr.iter().copied().collect();
2140    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
2141    let n = sorted.len();
2142    if n == 0 {
2143        return F::zero();
2144    }
2145    if n % 2 == 1 {
2146        sorted[n / 2]
2147    } else {
2148        (sorted[n / 2 - 1] + sorted[n / 2]) / F::from(2.0).unwrap()
2149    }
2150}
2151
2152// ---------------------------------------------------------------------------
2153// Tests
2154// ---------------------------------------------------------------------------
2155
2156#[cfg(test)]
2157mod tests {
2158    use super::*;
2159    use approx::assert_relative_eq;
2160    use ndarray::array;
2161
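    // -- Helper tests --

    // Sanity check for the private median helper, with hand-checked values.
    #[test]
    fn test_median_f_odd_and_even() {
        // Odd length: the middle element after sorting.
        assert_relative_eq!(median_f(&array![3.0, 1.0, 2.0]), 2.0);
        // Even length: the mean of the two middle elements.
        assert_relative_eq!(median_f(&array![4.0, 1.0, 3.0, 2.0]), 2.5);
        // Empty input falls back to zero.
        assert_relative_eq!(median_f(&Array1::<f64>::zeros(0)), 0.0);
    }
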
2162    // -- Binning tests --
2163
2164    #[test]
2165    fn test_bin_edges_simple() {
2166        let x =
2167            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
2168        let infos = compute_bin_edges(&x, 4);
2169        assert_eq!(infos.len(), 1);
2170        // Should have up to 4 bins.
2171        assert!(infos[0].n_bins <= 4);
2172        assert!(infos[0].n_bins >= 2);
2173        assert!(!infos[0].has_nan);
2174    }
2175
2176    #[test]
2177    fn test_bin_edges_with_nan() {
2178        let x = Array2::from_shape_vec((5, 1), vec![1.0, f64::NAN, 3.0, f64::NAN, 5.0]).unwrap();
2179        let infos = compute_bin_edges(&x, 4);
2180        assert_eq!(infos.len(), 1);
2181        assert!(infos[0].has_nan);
2182        assert!(infos[0].n_bins >= 1);
2183    }
2184
2185    #[test]
2186    fn test_bin_edges_all_nan() {
2187        let x = Array2::from_shape_vec((3, 1), vec![f64::NAN, f64::NAN, f64::NAN]).unwrap();
2188        let infos = compute_bin_edges(&x, 4);
2189        assert_eq!(infos[0].n_bins, 0);
2190        assert!(infos[0].has_nan);
2191    }
2192
2193    #[test]
2194    fn test_map_to_bin_basic() {
2195        let info = FeatureBinInfo {
2196            edges: vec![2.0, 4.0, 6.0, 8.0],
2197            n_bins: 4,
2198            has_nan: false,
2199        };
2200        assert_eq!(map_to_bin(1.0, &info), 0);
2201        assert_eq!(map_to_bin(2.0, &info), 0);
2202        assert_eq!(map_to_bin(3.0, &info), 1);
2203        assert_eq!(map_to_bin(5.0, &info), 2);
2204        assert_eq!(map_to_bin(7.0, &info), 3);
2205        assert_eq!(map_to_bin(9.0, &info), 3);
2206        assert_eq!(map_to_bin(f64::NAN, &info), NAN_BIN);
2207    }
2208
2209    #[test]
2210    fn test_bin_data_roundtrip() {
2211        let x = Array2::from_shape_vec((4, 2), vec![1.0, 10.0, 2.0, 20.0, 3.0, 30.0, 4.0, 40.0])
2212            .unwrap();
2213        let infos = compute_bin_edges(&x, 255);
2214        let binned = bin_data(&x, &infos);
2215        assert_eq!(binned.len(), 4);
2216        assert_eq!(binned[0].len(), 2);
2217        // Values should be monotonically non-decreasing since input is sorted.
2218        for window in binned.windows(2) {
2219            let (prev, curr) = (&window[0], &window[1]);
2220            for (&p, &c) in prev.iter().zip(curr.iter()) {
2221                assert!(c >= p);
2222            }
2223        }
2224    }
2225
2226    // -- Subtraction trick test --
2227
2228    #[test]
2229    fn test_subtraction_trick() {
2230        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
2231        let bin_infos = compute_bin_edges(&x, 255);
2232        let binned = bin_data(&x, &bin_infos);
2233
2234        let gradients = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6];
2235        let hessians = vec![1.0; 6];
2236        let all_indices: Vec<usize> = (0..6).collect();
2237        let left_indices: Vec<usize> = vec![0, 1, 2];
2238        let right_indices: Vec<usize> = vec![3, 4, 5];
2239
2240        let parent_hist =
2241            build_histograms(&binned, &gradients, &hessians, &all_indices, &bin_infos);
2242        let left_hist = build_histograms(&binned, &gradients, &hessians, &left_indices, &bin_infos);
2243        let right_from_sub = subtract_histograms(&parent_hist, &left_hist, &bin_infos);
2244        let right_direct =
2245            build_histograms(&binned, &gradients, &hessians, &right_indices, &bin_infos);
2246
2247        // The subtraction trick result should match direct computation.
2248        for j in 0..bin_infos.len() {
2249            let n_bins = bin_infos[j].n_bins as usize;
2250            for b in 0..n_bins {
2251                assert_relative_eq!(
2252                    right_from_sub[j].bins[b].grad_sum,
2253                    right_direct[j].bins[b].grad_sum,
2254                    epsilon = 1e-10
2255                );
2256                assert_relative_eq!(
2257                    right_from_sub[j].bins[b].hess_sum,
2258                    right_direct[j].bins[b].hess_sum,
2259                    epsilon = 1e-10
2260                );
2261                assert_eq!(
2262                    right_from_sub[j].bins[b].count,
2263                    right_direct[j].bins[b].count
2264                );
2265            }
2266        }
2267    }
2268
2269    // -- Regressor tests --
2270
2271    #[test]
2272    fn test_hgbr_simple_least_squares() {
2273        let x =
2274            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
2275        let y = array![1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0];
2276
2277        let model = HistGradientBoostingRegressor::<f64>::new()
2278            .with_n_estimators(50)
2279            .with_learning_rate(0.1)
2280            .with_min_samples_leaf(1)
2281            .with_max_leaf_nodes(None)
2282            .with_max_depth(Some(3))
2283            .with_random_state(42);
2284        let fitted = model.fit(&x, &y).unwrap();
2285        let preds = fitted.predict(&x).unwrap();
2286
2287        assert_eq!(preds.len(), 8);
2288        for i in 0..4 {
2289            assert!(preds[i] < 3.0, "Expected ~1.0, got {}", preds[i]);
2290        }
2291        for i in 4..8 {
2292            assert!(preds[i] > 3.0, "Expected ~5.0, got {}", preds[i]);
2293        }
2294    }
2295
2296    #[test]
2297    fn test_hgbr_lad_loss() {
2298        let x =
2299            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
2300        let y = array![1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0];
2301
2302        let model = HistGradientBoostingRegressor::<f64>::new()
2303            .with_n_estimators(50)
2304            .with_loss(HistRegressionLoss::LeastAbsoluteDeviation)
2305            .with_min_samples_leaf(1)
2306            .with_max_leaf_nodes(None)
2307            .with_max_depth(Some(3))
2308            .with_random_state(42);
2309        let fitted = model.fit(&x, &y).unwrap();
2310        let preds = fitted.predict(&x).unwrap();
2311
2312        assert_eq!(preds.len(), 8);
2313        for i in 0..4 {
2314            assert!(preds[i] < 3.5, "LAD expected <3.5, got {}", preds[i]);
2315        }
2316        for i in 4..8 {
2317            assert!(preds[i] > 2.5, "LAD expected >2.5, got {}", preds[i]);
2318        }
2319    }
2320
2321    #[test]
2322    fn test_hgbr_reproducibility() {
2323        let x =
2324            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
2325        let y = array![1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0];
2326
2327        let model = HistGradientBoostingRegressor::<f64>::new()
2328            .with_n_estimators(20)
2329            .with_min_samples_leaf(1)
2330            .with_max_leaf_nodes(None)
2331            .with_max_depth(Some(3))
2332            .with_random_state(123);
2333
2334        let fitted1 = model.fit(&x, &y).unwrap();
2335        let fitted2 = model.fit(&x, &y).unwrap();
2336
2337        let preds1 = fitted1.predict(&x).unwrap();
2338        let preds2 = fitted2.predict(&x).unwrap();
2339
2340        for (p1, p2) in preds1.iter().zip(preds2.iter()) {
2341            assert_relative_eq!(*p1, *p2, epsilon = 1e-10);
2342        }
2343    }
2344
2345    #[test]
2346    fn test_hgbr_feature_importances() {
2347        let x = Array2::from_shape_vec(
2348            (10, 3),
2349            vec![
2350                1.0, 0.0, 0.0, 2.0, 0.0, 0.0, 3.0, 0.0, 0.0, 4.0, 0.0, 0.0, 5.0, 0.0, 0.0, 6.0,
2351                0.0, 0.0, 7.0, 0.0, 0.0, 8.0, 0.0, 0.0, 9.0, 0.0, 0.0, 10.0, 0.0, 0.0,
2352            ],
2353        )
2354        .unwrap();
2355        let y = array![1.0, 1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0, 5.0];
2356
2357        let model = HistGradientBoostingRegressor::<f64>::new()
2358            .with_n_estimators(20)
2359            .with_min_samples_leaf(1)
2360            .with_max_leaf_nodes(None)
2361            .with_max_depth(Some(3))
2362            .with_random_state(42);
2363        let fitted = model.fit(&x, &y).unwrap();
2364        let importances = fitted.feature_importances();
2365
2366        assert_eq!(importances.len(), 3);
2367        // First feature should be most important since it's the only one with variance.
2368        assert!(
2369            importances[0] > importances[1],
2370            "Expected imp[0]={} > imp[1]={}",
2371            importances[0],
2372            importances[1]
2373        );
2374    }
2375
2376    #[test]
2377    fn test_hgbr_shape_mismatch_fit() {
2378        let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
2379        let y = array![1.0, 2.0];
2380
2381        let model = HistGradientBoostingRegressor::<f64>::new().with_n_estimators(5);
2382        assert!(model.fit(&x, &y).is_err());
2383    }
2384
2385    #[test]
2386    fn test_hgbr_shape_mismatch_predict() {
2387        let x =
2388            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
2389        let y = array![1.0, 2.0, 3.0, 4.0];
2390
2391        let model = HistGradientBoostingRegressor::<f64>::new()
2392            .with_n_estimators(5)
2393            .with_min_samples_leaf(1)
2394            .with_max_leaf_nodes(None)
2395            .with_max_depth(Some(3))
2396            .with_random_state(0);
2397        let fitted = model.fit(&x, &y).unwrap();
2398
2399        let x_bad = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
2400        assert!(fitted.predict(&x_bad).is_err());
2401    }
2402
2403    #[test]
2404    fn test_hgbr_empty_data() {
2405        let x = Array2::<f64>::zeros((0, 2));
2406        let y = Array1::<f64>::zeros(0);
2407
2408        let model = HistGradientBoostingRegressor::<f64>::new().with_n_estimators(5);
2409        assert!(model.fit(&x, &y).is_err());
2410    }
2411
2412    #[test]
2413    fn test_hgbr_zero_estimators() {
2414        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
2415        let y = array![1.0, 2.0, 3.0, 4.0];
2416
2417        let model = HistGradientBoostingRegressor::<f64>::new().with_n_estimators(0);
2418        assert!(model.fit(&x, &y).is_err());
2419    }
2420
2421    #[test]
2422    fn test_hgbr_invalid_learning_rate() {
2423        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
2424        let y = array![1.0, 2.0, 3.0, 4.0];
2425
2426        let model = HistGradientBoostingRegressor::<f64>::new()
2427            .with_n_estimators(5)
2428            .with_learning_rate(0.0);
2429        assert!(model.fit(&x, &y).is_err());
2430    }
2431
2432    #[test]
2433    fn test_hgbr_invalid_max_bins() {
2434        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
2435        let y = array![1.0, 2.0, 3.0, 4.0];
2436
2437        let model = HistGradientBoostingRegressor::<f64>::new()
2438            .with_n_estimators(5)
2439            .with_max_bins(1);
2440        assert!(model.fit(&x, &y).is_err());
2441    }
2442
2443    #[test]
2444    fn test_hgbr_default_trait() {
2445        let model = HistGradientBoostingRegressor::<f64>::default();
2446        assert_eq!(model.n_estimators, 100);
2447        assert!((model.learning_rate - 0.1).abs() < 1e-10);
2448        assert_eq!(model.max_bins, 255);
2449    }
2450
2451    #[test]
2452    fn test_hgbr_pipeline_integration() {
2453        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
2454        let y = array![1.0, 2.0, 3.0, 4.0];
2455
2456        let model = HistGradientBoostingRegressor::<f64>::new()
2457            .with_n_estimators(10)
2458            .with_min_samples_leaf(1)
2459            .with_max_leaf_nodes(None)
2460            .with_max_depth(Some(3))
2461            .with_random_state(42);
2462        let fitted = model.fit_pipeline(&x, &y).unwrap();
2463        let preds = fitted.predict_pipeline(&x).unwrap();
2464        assert_eq!(preds.len(), 4);
2465    }
2466
2467    #[test]
2468    fn test_hgbr_f32_support() {
2469        let x = Array2::from_shape_vec((4, 1), vec![1.0f32, 2.0, 3.0, 4.0]).unwrap();
2470        let y = Array1::from_vec(vec![1.0f32, 2.0, 3.0, 4.0]);
2471
2472        let model = HistGradientBoostingRegressor::<f32>::new()
2473            .with_n_estimators(10)
2474            .with_min_samples_leaf(1)
2475            .with_max_leaf_nodes(None)
2476            .with_max_depth(Some(3))
2477            .with_random_state(42);
2478        let fitted = model.fit(&x, &y).unwrap();
2479        let preds = fitted.predict(&x).unwrap();
2480        assert_eq!(preds.len(), 4);
2481    }
2482
2483    #[test]
2484    fn test_hgbr_nan_handling() {
2485        // Some feature values are NaN; the model should handle them gracefully.
2486        let x = Array2::from_shape_vec(
2487            (8, 1),
2488            vec![1.0, f64::NAN, 3.0, 4.0, 5.0, f64::NAN, 7.0, 8.0],
2489        )
2490        .unwrap();
2491        let y = array![1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0];
2492
2493        let model = HistGradientBoostingRegressor::<f64>::new()
2494            .with_n_estimators(20)
2495            .with_min_samples_leaf(1)
2496            .with_max_leaf_nodes(None)
2497            .with_max_depth(Some(3))
2498            .with_random_state(42);
2499        let fitted = model.fit(&x, &y).unwrap();
2500        let preds = fitted.predict(&x).unwrap();
2501        assert_eq!(preds.len(), 8);
2502        // Should still produce finite predictions for all samples (including NaN inputs).
2503        for p in &preds {
2504            assert!(p.is_finite(), "Expected finite prediction, got {p}");
2505        }
2506    }
2507
2508    #[test]
2509    fn test_hgbr_convergence() {
2510        // MSE should decrease as we add more estimators.
2511        let x =
2512            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
2513        let y = array![1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0];
2514
2515        let mse = |preds: &Array1<f64>, y: &Array1<f64>| -> f64 {
2516            preds
2517                .iter()
2518                .zip(y.iter())
2519                .map(|(p, t)| (p - t).powi(2))
2520                .sum::<f64>()
2521                / y.len() as f64
2522        };
2523
2524        let model_few = HistGradientBoostingRegressor::<f64>::new()
2525            .with_n_estimators(5)
2526            .with_min_samples_leaf(1)
2527            .with_max_leaf_nodes(None)
2528            .with_max_depth(Some(3))
2529            .with_random_state(42);
2530        let fitted_few = model_few.fit(&x, &y).unwrap();
2531        let preds_few = fitted_few.predict(&x).unwrap();
2532        let mse_few = mse(&preds_few, &y);
2533
2534        let model_many = HistGradientBoostingRegressor::<f64>::new()
2535            .with_n_estimators(50)
2536            .with_min_samples_leaf(1)
2537            .with_max_leaf_nodes(None)
2538            .with_max_depth(Some(3))
2539            .with_random_state(42);
2540        let fitted_many = model_many.fit(&x, &y).unwrap();
2541        let preds_many = fitted_many.predict(&x).unwrap();
2542        let mse_many = mse(&preds_many, &y);
2543
2544        assert!(
2545            mse_many < mse_few,
2546            "Expected MSE to decrease with more estimators: {mse_many} (50) vs {mse_few} (5)"
2547        );
2548    }
2549
2550    #[test]
2551    fn test_hgbr_max_leaf_nodes() {
2552        // Test that best-first growth with max_leaf_nodes works.
2553        let x =
2554            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
2555        let y = array![1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0];
2556
2557        let model = HistGradientBoostingRegressor::<f64>::new()
2558            .with_n_estimators(20)
2559            .with_min_samples_leaf(1)
2560            .with_max_leaf_nodes(Some(4))
2561            .with_random_state(42);
2562        let fitted = model.fit(&x, &y).unwrap();
2563        let preds = fitted.predict(&x).unwrap();
2564        assert_eq!(preds.len(), 8);
2565    }
2566
2567    #[test]
2568    fn test_hgbr_l2_regularization() {
2569        let x =
2570            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
2571        let y = array![1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0];
2572
2573        // With very high regularization, predictions should be closer to the mean.
2574        let model_noreg = HistGradientBoostingRegressor::<f64>::new()
2575            .with_n_estimators(20)
2576            .with_min_samples_leaf(1)
2577            .with_max_leaf_nodes(None)
2578            .with_max_depth(Some(3))
2579            .with_l2_regularization(0.0)
2580            .with_random_state(42);
2581        let fitted_noreg = model_noreg.fit(&x, &y).unwrap();
2582        let preds_noreg = fitted_noreg.predict(&x).unwrap();
2583
2584        let model_highreg = HistGradientBoostingRegressor::<f64>::new()
2585            .with_n_estimators(20)
2586            .with_min_samples_leaf(1)
2587            .with_max_leaf_nodes(None)
2588            .with_max_depth(Some(3))
2589            .with_l2_regularization(100.0)
2590            .with_random_state(42);
2591        let fitted_highreg = model_highreg.fit(&x, &y).unwrap();
2592        let preds_highreg = fitted_highreg.predict(&x).unwrap();
2593
2594        // With high reg, variance of predictions should be smaller.
2595        let var = |preds: &Array1<f64>| -> f64 {
2596            let mean = preds.mean().unwrap();
2597            preds.iter().map(|p| (p - mean).powi(2)).sum::<f64>() / preds.len() as f64
2598        };
2599
2600        assert!(
2601            var(&preds_highreg) < var(&preds_noreg),
2602            "High regularization should reduce prediction variance"
2603        );
2604    }
2605
2606    // -- Classifier tests --
2607
2608    #[test]
2609    fn test_hgbc_binary_simple() {
2610        let x = Array2::from_shape_vec(
2611            (8, 2),
2612            vec![
2613                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
2614            ],
2615        )
2616        .unwrap();
2617        let y = array![0, 0, 0, 0, 1, 1, 1, 1];
2618
2619        let model = HistGradientBoostingClassifier::<f64>::new()
2620            .with_n_estimators(50)
2621            .with_learning_rate(0.1)
2622            .with_min_samples_leaf(1)
2623            .with_max_leaf_nodes(None)
2624            .with_max_depth(Some(3))
2625            .with_random_state(42);
2626        let fitted = model.fit(&x, &y).unwrap();
2627        let preds = fitted.predict(&x).unwrap();
2628
2629        assert_eq!(preds.len(), 8);
2630        for i in 0..4 {
2631            assert_eq!(preds[i], 0, "Expected 0 at index {}, got {}", i, preds[i]);
2632        }
2633        for i in 4..8 {
2634            assert_eq!(preds[i], 1, "Expected 1 at index {}, got {}", i, preds[i]);
2635        }
2636    }
2637
2638    #[test]
2639    fn test_hgbc_multiclass() {
2640        let x = Array2::from_shape_vec((9, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0])
2641            .unwrap();
2642        let y = array![0, 0, 0, 1, 1, 1, 2, 2, 2];
2643
2644        let model = HistGradientBoostingClassifier::<f64>::new()
2645            .with_n_estimators(50)
2646            .with_learning_rate(0.1)
2647            .with_min_samples_leaf(1)
2648            .with_max_leaf_nodes(None)
2649            .with_max_depth(Some(3))
2650            .with_random_state(42);
2651        let fitted = model.fit(&x, &y).unwrap();
2652        let preds = fitted.predict(&x).unwrap();
2653
2654        assert_eq!(preds.len(), 9);
2655        let correct = preds.iter().zip(y.iter()).filter(|(p, t)| p == t).count();
2656        assert!(
2657            correct >= 6,
2658            "Expected at least 6/9 correct, got {correct}/9"
2659        );
2660    }
2661
2662    #[test]
2663    fn test_hgbc_has_classes() {
2664        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
2665        let y = array![0, 1, 2, 0, 1, 2];
2666
2667        let model = HistGradientBoostingClassifier::<f64>::new()
2668            .with_n_estimators(5)
2669            .with_min_samples_leaf(1)
2670            .with_max_leaf_nodes(None)
2671            .with_max_depth(Some(3))
2672            .with_random_state(0);
2673        let fitted = model.fit(&x, &y).unwrap();
2674
2675        assert_eq!(fitted.classes(), &[0, 1, 2]);
2676        assert_eq!(fitted.n_classes(), 3);
2677    }
2678
2679    #[test]
2680    fn test_hgbc_reproducibility() {
2681        let x = Array2::from_shape_vec(
2682            (8, 2),
2683            vec![
2684                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
2685            ],
2686        )
2687        .unwrap();
2688        let y = array![0, 0, 0, 0, 1, 1, 1, 1];
2689
2690        let model = HistGradientBoostingClassifier::<f64>::new()
2691            .with_n_estimators(10)
2692            .with_min_samples_leaf(1)
2693            .with_max_leaf_nodes(None)
2694            .with_max_depth(Some(3))
2695            .with_random_state(42);
2696
2697        let fitted1 = model.fit(&x, &y).unwrap();
2698        let fitted2 = model.fit(&x, &y).unwrap();
2699
2700        let preds1 = fitted1.predict(&x).unwrap();
2701        let preds2 = fitted2.predict(&x).unwrap();
2702        assert_eq!(preds1, preds2);
2703    }
2704
2705    #[test]
2706    fn test_hgbc_feature_importances() {
2707        let x = Array2::from_shape_vec(
2708            (10, 3),
2709            vec![
2710                1.0, 0.0, 0.0, 2.0, 0.0, 0.0, 3.0, 0.0, 0.0, 4.0, 0.0, 0.0, 5.0, 0.0, 0.0, 6.0,
2711                0.0, 0.0, 7.0, 0.0, 0.0, 8.0, 0.0, 0.0, 9.0, 0.0, 0.0, 10.0, 0.0, 0.0,
2712            ],
2713        )
2714        .unwrap();
2715        let y = array![0, 0, 0, 0, 0, 1, 1, 1, 1, 1];
2716
2717        let model = HistGradientBoostingClassifier::<f64>::new()
2718            .with_n_estimators(20)
2719            .with_min_samples_leaf(1)
2720            .with_max_leaf_nodes(None)
2721            .with_max_depth(Some(3))
2722            .with_random_state(42);
2723        let fitted = model.fit(&x, &y).unwrap();
2724        let importances = fitted.feature_importances();
2725
2726        assert_eq!(importances.len(), 3);
2727        assert!(importances[0] > importances[1]);
2728    }
2729
2730    #[test]
2731    fn test_hgbc_shape_mismatch_fit() {
2732        let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
2733        let y = array![0, 1];
2734
2735        let model = HistGradientBoostingClassifier::<f64>::new().with_n_estimators(5);
2736        assert!(model.fit(&x, &y).is_err());
2737    }
2738
2739    #[test]
2740    fn test_hgbc_shape_mismatch_predict() {
2741        let x =
2742            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
2743        let y = array![0, 0, 1, 1];
2744
2745        let model = HistGradientBoostingClassifier::<f64>::new()
2746            .with_n_estimators(5)
2747            .with_min_samples_leaf(1)
2748            .with_max_leaf_nodes(None)
2749            .with_max_depth(Some(3))
2750            .with_random_state(0);
2751        let fitted = model.fit(&x, &y).unwrap();
2752
2753        let x_bad = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
2754        assert!(fitted.predict(&x_bad).is_err());
2755    }
2756
2757    #[test]
2758    fn test_hgbc_empty_data() {
2759        let x = Array2::<f64>::zeros((0, 2));
2760        let y = Array1::<usize>::zeros(0);
2761
2762        let model = HistGradientBoostingClassifier::<f64>::new().with_n_estimators(5);
2763        assert!(model.fit(&x, &y).is_err());
2764    }
2765
2766    #[test]
2767    fn test_hgbc_single_class() {
2768        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
2769        let y = array![0, 0, 0];
2770
2771        let model = HistGradientBoostingClassifier::<f64>::new().with_n_estimators(5);
2772        assert!(model.fit(&x, &y).is_err());
2773    }
2774
2775    #[test]
2776    fn test_hgbc_zero_estimators() {
2777        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
2778        let y = array![0, 0, 1, 1];
2779
2780        let model = HistGradientBoostingClassifier::<f64>::new().with_n_estimators(0);
2781        assert!(model.fit(&x, &y).is_err());
2782    }
2783
2784    #[test]
2785    fn test_hgbc_pipeline_integration() {
2786        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
2787        let y = Array1::from_vec(vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0]);
2788
2789        let model = HistGradientBoostingClassifier::<f64>::new()
2790            .with_n_estimators(10)
2791            .with_min_samples_leaf(1)
2792            .with_max_leaf_nodes(None)
2793            .with_max_depth(Some(3))
2794            .with_random_state(42);
2795        let fitted = model.fit_pipeline(&x, &y).unwrap();
2796        let preds = fitted.predict_pipeline(&x).unwrap();
2797        assert_eq!(preds.len(), 6);
2798    }
2799
2800    #[test]
2801    fn test_hgbc_f32_support() {
2802        let x = Array2::from_shape_vec((6, 1), vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
2803        let y = array![0, 0, 0, 1, 1, 1];
2804
2805        let model = HistGradientBoostingClassifier::<f32>::new()
2806            .with_n_estimators(10)
2807            .with_min_samples_leaf(1)
2808            .with_max_leaf_nodes(None)
2809            .with_max_depth(Some(3))
2810            .with_random_state(42);
2811        let fitted = model.fit(&x, &y).unwrap();
2812        let preds = fitted.predict(&x).unwrap();
2813        assert_eq!(preds.len(), 6);
2814    }
2815
2816    #[test]
2817    fn test_hgbc_default_trait() {
2818        let model = HistGradientBoostingClassifier::<f64>::default();
2819        assert_eq!(model.n_estimators, 100);
2820        assert!((model.learning_rate - 0.1).abs() < 1e-10);
2821        assert_eq!(model.max_bins, 255);
2822    }
2823
2824    #[test]
2825    fn test_hgbc_non_contiguous_labels() {
2826        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
2827        let y = array![10, 10, 10, 20, 20, 20];
2828
2829        let model = HistGradientBoostingClassifier::<f64>::new()
2830            .with_n_estimators(20)
2831            .with_min_samples_leaf(1)
2832            .with_max_leaf_nodes(None)
2833            .with_max_depth(Some(3))
2834            .with_random_state(42);
2835        let fitted = model.fit(&x, &y).unwrap();
2836        let preds = fitted.predict(&x).unwrap();
2837
2838        assert_eq!(preds.len(), 6);
2839        for &p in &preds {
2840            assert!(p == 10 || p == 20);
2841        }
2842    }
2843
2844    #[test]
2845    fn test_hgbc_nan_handling() {
2846        // Classifier should handle NaN features.
2847        let x = Array2::from_shape_vec(
2848            (8, 2),
2849            vec![
2850                1.0,
2851                f64::NAN,
2852                2.0,
2853                3.0,
2854                f64::NAN,
2855                3.0,
2856                4.0,
2857                4.0,
2858                5.0,
2859                6.0,
2860                6.0,
2861                f64::NAN,
2862                7.0,
2863                8.0,
2864                f64::NAN,
2865                9.0,
2866            ],
2867        )
2868        .unwrap();
2869        let y = array![0, 0, 0, 0, 1, 1, 1, 1];
2870
2871        let model = HistGradientBoostingClassifier::<f64>::new()
2872            .with_n_estimators(20)
2873            .with_min_samples_leaf(1)
2874            .with_max_leaf_nodes(None)
2875            .with_max_depth(Some(3))
2876            .with_random_state(42);
2877        let fitted = model.fit(&x, &y).unwrap();
2878        let preds = fitted.predict(&x).unwrap();
2879        assert_eq!(preds.len(), 8);
2880        // All predictions should be valid class labels.
2881        for &p in &preds {
2882            assert!(p == 0 || p == 1);
2883        }
2884    }
2885
2886    // -- Comparison with standard GBM --
2887
2888    #[test]
2889    fn test_hist_vs_standard_gbm_similar_accuracy() {
2890        // Both models should achieve comparable accuracy on a simple task.
2891        use crate::GradientBoostingRegressor;
2892
2893        let x = Array2::from_shape_vec(
2894            (10, 1),
2895            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0],
2896        )
2897        .unwrap();
2898        let y = array![1.0, 1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0, 5.0];
2899
2900        let mse = |preds: &Array1<f64>, y: &Array1<f64>| -> f64 {
2901            preds
2902                .iter()
2903                .zip(y.iter())
2904                .map(|(p, t)| (p - t).powi(2))
2905                .sum::<f64>()
2906                / y.len() as f64
2907        };
2908
2909        // Standard GBM.
2910        let std_model = GradientBoostingRegressor::<f64>::new()
2911            .with_n_estimators(50)
2912            .with_learning_rate(0.1)
2913            .with_max_depth(Some(3))
2914            .with_random_state(42);
2915        let std_fitted = std_model.fit(&x, &y).unwrap();
2916        let std_preds = std_fitted.predict(&x).unwrap();
2917        let std_mse = mse(&std_preds, &y);
2918
2919        // Histogram GBM.
2920        let hist_model = HistGradientBoostingRegressor::<f64>::new()
2921            .with_n_estimators(50)
2922            .with_learning_rate(0.1)
2923            .with_min_samples_leaf(1)
2924            .with_max_leaf_nodes(None)
2925            .with_max_depth(Some(3))
2926            .with_random_state(42);
2927        let hist_fitted = hist_model.fit(&x, &y).unwrap();
2928        let hist_preds = hist_fitted.predict(&x).unwrap();
2929        let hist_mse = mse(&hist_preds, &y);
2930
2931        // Both should have low MSE on this simple task.
2932        assert!(std_mse < 1.0, "Standard GBM MSE too high: {std_mse}");
2933        assert!(hist_mse < 1.0, "Hist GBM MSE too high: {hist_mse}");
2934    }
2935}