ferrolearn_tree/hist_gradient_boosting.rs

//! Histogram-based gradient boosting classifiers and regressors.
//!
//! This module provides [`HistGradientBoostingClassifier`] and
//! [`HistGradientBoostingRegressor`], which implement histogram-based gradient
//! boosting trees inspired by LightGBM / scikit-learn's HistGradientBoosting*.
//!
//! # Key design choices
//!
//! - **Feature binning**: continuous features are discretised into up to
//!   `max_bins` (default 255) bins using quantile-based bin edges. NaN values
//!   are assigned a dedicated bin.
//! - **Histogram-based split finding**: at each node, gradient/hessian sums are
//!   accumulated into per-bin histograms, making split finding O(n_bins) instead
//!   of O(n log n).
//! - **Subtraction trick**: the child histogram with more samples is computed by
//!   subtracting the smaller child's histogram from the parent, halving the work.
//! - **Missing value support**: NaN values are routed to whichever side of a
//!   split yields the best gain, enabling native handling of missing data.
//!
//! # Regression Losses
//!
//! - **`LeastSquares`**: mean squared error; gradient = `F(x) - y`, hessian = 1.
//! - **`LeastAbsoluteDeviation`**: gradient = `sign(F(x) - y)`, hessian = 1.
//!
//! # Classification Loss
//!
//! - **`LogLoss`**: binary and multiclass logistic loss (multiclass is handled
//!   via softmax, with one tree per class per boosting round).
//!
//! # Examples
//!
//! ```
//! use ferrolearn_tree::HistGradientBoostingRegressor;
//! use ferrolearn_core::{Fit, Predict};
//! use ndarray::{array, Array1, Array2};
//!
//! let x = Array2::from_shape_vec((8, 1), vec![
//!     1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
//! ]).unwrap();
//! let y = array![1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0];
//!
//! let model = HistGradientBoostingRegressor::<f64>::new()
//!     .with_n_estimators(50)
//!     .with_learning_rate(0.1)
//!     .with_random_state(42);
//! let fitted = model.fit(&x, &y).unwrap();
//! let preds = fitted.predict(&x).unwrap();
//! assert_eq!(preds.len(), 8);
//! ```

use ferrolearn_core::error::FerroError;
use ferrolearn_core::introspection::{HasClasses, HasFeatureImportances};
use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
use ferrolearn_core::traits::{Fit, Predict};
use ndarray::{Array1, Array2};
use num_traits::{Float, FromPrimitive, ToPrimitive};
use serde::{Deserialize, Serialize};

// ---------------------------------------------------------------------------
// Loss enums
// ---------------------------------------------------------------------------

/// Loss function for histogram-based gradient boosting regression.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum HistRegressionLoss {
    /// Least squares (L2) loss.
    LeastSquares,
    /// Least absolute deviation (L1) loss.
    LeastAbsoluteDeviation,
}

/// Loss function for histogram-based gradient boosting classification.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum HistClassificationLoss {
    /// Log-loss (logistic / cross-entropy) for binary and multiclass.
    LogLoss,
}

// ---------------------------------------------------------------------------
// Internal: histogram tree node
// ---------------------------------------------------------------------------

/// A node in a histogram-based gradient boosting tree.
///
/// Unlike [`crate::decision_tree::Node`], thresholds are stored as bin indices
/// rather than raw feature values. Prediction at a split checks whether the
/// sample's bin index for the feature is `<= threshold_bin`.
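///
/// For example, a depth-1 stump is stored as the flat vec
/// `[Split { left: 1, right: 2, .. }, Leaf { .. }, Leaf { .. }]`: the root is
/// always at index 0, and traversal simply follows the stored `left`/`right`
/// child indices.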
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum HistNode<F> {
    /// An internal split node.
    Split {
        /// Feature index used for the split.
        feature: usize,
        /// Bin index threshold; samples with `bin[feature] <= threshold_bin`
        /// go left.
        threshold_bin: u16,
        /// Whether NaN values should go left (`true`) or right (`false`).
        nan_goes_left: bool,
        /// Index of the left child node in the flat vec.
        left: usize,
        /// Index of the right child node in the flat vec.
        right: usize,
        /// Weighted gain from this split (for feature importance).
        gain: F,
        /// Number of samples that reached this node during training.
        n_samples: usize,
    },
    /// A leaf node that stores a prediction value.
    Leaf {
        /// Predicted value (raw leaf output, e.g. gradient step).
        value: F,
        /// Number of samples that reached this node during training.
        n_samples: usize,
    },
}

// ---------------------------------------------------------------------------
// Internal: binning
// ---------------------------------------------------------------------------

/// The special bin index reserved for NaN / missing values.
const NAN_BIN: u16 = u16::MAX;

/// Bin edges for a single feature, plus the number of non-NaN bins.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct FeatureBinInfo<F> {
    /// Sorted bin edges (upper thresholds). `edges[b]` is the upper bound
    /// for bin `b`. The last bin captures everything above `edges[len-2]`.
    edges: Vec<F>,
    /// Number of non-NaN bins for this feature (at most `max_bins`).
    n_bins: u16,
    /// Whether any NaN was observed for this feature during binning.
    has_nan: bool,
}

/// Compute quantile-based bin edges for every feature.
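///
/// As a worked sketch: for a column `[1, 2, 3, 4, 5, 6, 7, 8]` with
/// `max_bins = 4`, linear interpolation over the 8 unique values yields
/// edges `[2.75, 4.5, 6.25, 8.0]` (the final edge is the column maximum).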
fn compute_bin_edges<F: Float>(x: &Array2<F>, max_bins: u16) -> Vec<FeatureBinInfo<F>> {
    let n_features = x.ncols();
    let n_samples = x.nrows();
    let mut infos = Vec::with_capacity(n_features);

    for j in 0..n_features {
        let col = x.column(j);
        // Separate non-NaN values and sort them.
        let mut vals: Vec<F> = col.iter().copied().filter(|v| !v.is_nan()).collect();
        vals.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));

        let has_nan = vals.len() < n_samples;

        if vals.is_empty() {
            // All NaN.
            infos.push(FeatureBinInfo {
                edges: Vec::new(),
                n_bins: 0,
                has_nan: true,
            });
            continue;
        }

        // Deduplicate to find unique values.
        let mut unique: Vec<F> = Vec::new();
        for &v in &vals {
            if unique.is_empty() || (v - *unique.last().unwrap()).abs() > F::epsilon() {
                unique.push(v);
            }
        }

        let n_unique = unique.len();
        let actual_bins = (max_bins as usize).min(n_unique);

        if actual_bins <= 1 {
            // Only one unique value — one bin, edge is that value.
            infos.push(FeatureBinInfo {
                edges: vec![unique[0]],
                n_bins: 1,
                has_nan,
            });
            continue;
        }

        // Compute quantile-based bin edges.
        let mut edges = Vec::with_capacity(actual_bins);
        for b in 1..actual_bins {
            let frac = b as f64 / actual_bins as f64;
            let idx_f = frac * (n_unique as f64 - 1.0);
            let lo = idx_f.floor() as usize;
            let hi = (lo + 1).min(n_unique - 1);
            let t = F::from(idx_f - lo as f64).unwrap();
            let edge = unique[lo] * (F::one() - t) + unique[hi] * t;
            // Avoid duplicate edges.
            if edges.is_empty() || (edge - *edges.last().unwrap()).abs() > F::epsilon() {
                edges.push(edge);
            }
        }
        // Add a final edge for the upper bound (max value).
        let last = *unique.last().unwrap();
        if edges.is_empty() || (last - *edges.last().unwrap()).abs() > F::epsilon() {
            edges.push(last);
        }

        let n_bins = edges.len() as u16;
        infos.push(FeatureBinInfo {
            edges,
            n_bins,
            has_nan,
        });
    }

    infos
}

/// Map a single feature value to its bin index given bin edges.
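///
/// E.g. with edges `[2.75, 4.5, 6.25, 8.0]`: value `2.0` maps to bin 0,
/// `3.0` to bin 1, `9.0` clamps to the last bin (3), and NaN maps to
/// [`NAN_BIN`].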
#[inline]
fn map_to_bin<F: Float>(value: F, info: &FeatureBinInfo<F>) -> u16 {
    if value.is_nan() {
        return NAN_BIN;
    }
    if info.n_bins == 0 {
        return NAN_BIN;
    }
    // Binary search: find the first edge >= value.
    let edges = &info.edges;
    let mut lo: usize = 0;
    let mut hi: usize = edges.len();
    while lo < hi {
        let mid = lo + (hi - lo) / 2;
        if edges[mid] < value {
            lo = mid + 1;
        } else {
            hi = mid;
        }
    }
    if lo >= edges.len() {
        (info.n_bins - 1).min(edges.len() as u16 - 1)
    } else {
        lo as u16
    }
}

/// Bin all samples for all features.
fn bin_data<F: Float>(x: &Array2<F>, bin_infos: &[FeatureBinInfo<F>]) -> Vec<Vec<u16>> {
    let (n_samples, n_features) = x.dim();
    // bins[i][j] = bin index for sample i, feature j.
    let mut bins = vec![vec![0u16; n_features]; n_samples];
    for j in 0..n_features {
        for i in 0..n_samples {
            bins[i][j] = map_to_bin(x[[i, j]], &bin_infos[j]);
        }
    }
    bins
}

// ---------------------------------------------------------------------------
// Internal: histogram-based tree building
// ---------------------------------------------------------------------------

/// Gradient/hessian accumulator for a single bin.
#[derive(Debug, Clone, Copy, Default)]
struct BinEntry<F> {
    grad_sum: F,
    hess_sum: F,
    count: usize,
}

/// A histogram for one feature: one entry per bin, plus an optional NaN entry.
struct FeatureHistogram<F> {
    bins: Vec<BinEntry<F>>,
    nan_entry: BinEntry<F>,
}

impl<F: Float> FeatureHistogram<F> {
    fn new(n_bins: u16) -> Self {
        Self {
            bins: vec![
                BinEntry {
                    grad_sum: F::zero(),
                    hess_sum: F::zero(),
                    count: 0,
                };
                n_bins as usize
            ],
            nan_entry: BinEntry {
                grad_sum: F::zero(),
                hess_sum: F::zero(),
                count: 0,
            },
        }
    }
}

/// Build histograms for all features from scratch over the given sample indices.
fn build_histograms<F: Float>(
    bin_data: &[Vec<u16>],
    gradients: &[F],
    hessians: &[F],
    sample_indices: &[usize],
    bin_infos: &[FeatureBinInfo<F>],
) -> Vec<FeatureHistogram<F>> {
    let n_features = bin_infos.len();
    let mut histograms: Vec<FeatureHistogram<F>> = bin_infos
        .iter()
        .map(|info| FeatureHistogram::new(info.n_bins))
        .collect();

    for &i in sample_indices {
        let g = gradients[i];
        let h = hessians[i];
        for j in 0..n_features {
            let b = bin_data[i][j];
            if b == NAN_BIN {
                histograms[j].nan_entry.grad_sum = histograms[j].nan_entry.grad_sum + g;
                histograms[j].nan_entry.hess_sum = histograms[j].nan_entry.hess_sum + h;
                histograms[j].nan_entry.count += 1;
            } else {
                let entry = &mut histograms[j].bins[b as usize];
                entry.grad_sum = entry.grad_sum + g;
                entry.hess_sum = entry.hess_sum + h;
                entry.count += 1;
            }
        }
    }

    histograms
}

/// Subtraction trick: compute child histogram from parent and sibling.
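///
/// E.g. if a parent bin holds grad_sum 10.0 over 8 samples and the sibling's
/// bin holds grad_sum 3.0 over 5 samples, the other child's bin is grad_sum
/// 7.0 over 3 samples, with no re-scan of that child's rows.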
fn subtract_histograms<F: Float>(
    parent: &[FeatureHistogram<F>],
    sibling: &[FeatureHistogram<F>],
    bin_infos: &[FeatureBinInfo<F>],
) -> Vec<FeatureHistogram<F>> {
    let n_features = bin_infos.len();
    let mut result: Vec<FeatureHistogram<F>> = bin_infos
        .iter()
        .map(|info| FeatureHistogram::new(info.n_bins))
        .collect();

    for j in 0..n_features {
        let n_bins = bin_infos[j].n_bins as usize;
        for b in 0..n_bins {
            result[j].bins[b].grad_sum = parent[j].bins[b].grad_sum - sibling[j].bins[b].grad_sum;
            result[j].bins[b].hess_sum = parent[j].bins[b].hess_sum - sibling[j].bins[b].hess_sum;
            // count is usize so we need saturating_sub for safety
            result[j].bins[b].count = parent[j].bins[b]
                .count
                .saturating_sub(sibling[j].bins[b].count);
        }
        // NaN bin
        result[j].nan_entry.grad_sum = parent[j].nan_entry.grad_sum - sibling[j].nan_entry.grad_sum;
        result[j].nan_entry.hess_sum = parent[j].nan_entry.hess_sum - sibling[j].nan_entry.hess_sum;
        result[j].nan_entry.count = parent[j]
            .nan_entry
            .count
            .saturating_sub(sibling[j].nan_entry.count);
    }

    result
}

/// Result of finding the best split for a node.
struct SplitCandidate<F> {
    feature: usize,
    threshold_bin: u16,
    gain: F,
    nan_goes_left: bool,
}

/// Find the best split across all features from histograms.
///
/// Uses the standard gain formula:
///   gain = (G_L^2 / (H_L + lambda)) + (G_R^2 / (H_R + lambda)) - (G_parent^2 / (H_parent + lambda))
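///
/// E.g. with lambda = 0, G_L = -4, H_L = 4, G_R = 4, H_R = 4 (so the parent
/// has G = 0, H = 8): gain = 16/4 + 16/4 - 0 = 8.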
fn find_best_split_from_histograms<F: Float>(
    histograms: &[FeatureHistogram<F>],
    bin_infos: &[FeatureBinInfo<F>],
    total_grad: F,
    total_hess: F,
    total_count: usize,
    l2_regularization: F,
    min_samples_leaf: usize,
) -> Option<SplitCandidate<F>> {
    let n_features = bin_infos.len();
    let parent_gain = total_grad * total_grad / (total_hess + l2_regularization);

    let mut best: Option<SplitCandidate<F>> = None;

    for j in 0..n_features {
        let n_bins = bin_infos[j].n_bins as usize;
        if n_bins <= 1 {
            continue;
        }
        let nan = &histograms[j].nan_entry;

        // Scan left-to-right through the bins.
        // For split position b, the left child holds bins 0..=b and the right
        // child holds bins b+1..n_bins. NaN samples can go either left or
        // right; we try both.
        let mut left_grad = F::zero();
        let mut left_hess = F::zero();
        let mut left_count: usize = 0;

        for b in 0..(n_bins - 1) {
            let entry = &histograms[j].bins[b];
            left_grad = left_grad + entry.grad_sum;
            left_hess = left_hess + entry.hess_sum;
            left_count += entry.count;

            let right_grad_no_nan = total_grad - left_grad - nan.grad_sum;
            let right_hess_no_nan = total_hess - left_hess - nan.hess_sum;
            let right_count_no_nan = total_count
                .saturating_sub(left_count)
                .saturating_sub(nan.count);

            // Try NaN goes left.
            {
                let lg = left_grad + nan.grad_sum;
                let lh = left_hess + nan.hess_sum;
                let lc = left_count + nan.count;
                let rg = right_grad_no_nan;
                let rh = right_hess_no_nan;
                let rc = right_count_no_nan;

                if lc >= min_samples_leaf && rc >= min_samples_leaf {
                    let gain = lg * lg / (lh + l2_regularization)
                        + rg * rg / (rh + l2_regularization)
                        - parent_gain;
                    if gain > F::zero() {
                        let better = match &best {
                            None => true,
                            Some(curr) => gain > curr.gain,
                        };
                        if better {
                            best = Some(SplitCandidate {
                                feature: j,
                                threshold_bin: b as u16,
                                gain,
                                nan_goes_left: true,
                            });
                        }
                    }
                }
            }

            // Try NaN goes right.
            {
                let lg = left_grad;
                let lh = left_hess;
                let lc = left_count;
                let rg = right_grad_no_nan + nan.grad_sum;
                let rh = right_hess_no_nan + nan.hess_sum;
                let rc = right_count_no_nan + nan.count;

                if lc >= min_samples_leaf && rc >= min_samples_leaf {
                    let gain = lg * lg / (lh + l2_regularization)
                        + rg * rg / (rh + l2_regularization)
                        - parent_gain;
                    if gain > F::zero() {
                        let better = match &best {
                            None => true,
                            Some(curr) => gain > curr.gain,
                        };
                        if better {
                            best = Some(SplitCandidate {
                                feature: j,
                                threshold_bin: b as u16,
                                gain,
                                nan_goes_left: false,
                            });
                        }
                    }
                }
            }
        }
    }

    best
}

/// Compute leaf value: -G / (H + lambda).
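///
/// E.g. grad_sum = -6, hess_sum = 3, lambda = 0 gives a leaf value of 2.0;
/// under least squares (unit hessians) this is the mean residual of the
/// leaf's samples.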
#[inline]
fn compute_leaf_value<F: Float>(grad_sum: F, hess_sum: F, l2_reg: F) -> F {
    if hess_sum.abs() < F::epsilon() {
        F::zero()
    } else {
        -grad_sum / (hess_sum + l2_reg)
    }
}

/// Parameters for histogram tree building.
struct HistTreeParams<F> {
    max_depth: Option<usize>,
    min_samples_leaf: usize,
    max_leaf_nodes: Option<usize>,
    l2_regularization: F,
}

/// Build a single histogram-based regression tree.
///
/// Returns a vector of `HistNode`s (flat representation).
fn build_hist_tree<F: Float>(
    binned: &[Vec<u16>],
    gradients: &[F],
    hessians: &[F],
    sample_indices: &[usize],
    bin_infos: &[FeatureBinInfo<F>],
    params: &HistTreeParams<F>,
) -> Vec<HistNode<F>> {

    if sample_indices.is_empty() {
        return vec![HistNode::Leaf {
            value: F::zero(),
            n_samples: 0,
        }];
    }

    // If max_leaf_nodes is set, grow best-first (leaf-wise) via a work queue;
    // otherwise grow depth-first recursively.
    if params.max_leaf_nodes.is_some() {
        build_hist_tree_best_first(
            binned,
            gradients,
            hessians,
            sample_indices,
            bin_infos,
            params,
        )
    } else {
        let mut nodes = Vec::new();
        let hist = build_histograms(binned, gradients, hessians, sample_indices, bin_infos);
        build_hist_tree_recursive(
            binned,
            gradients,
            hessians,
            sample_indices,
            bin_infos,
            params,
            &hist,
            0,
            &mut nodes,
        );
        nodes
    }
}

/// Recursive depth-first histogram tree building.
#[allow(clippy::too_many_arguments)]
fn build_hist_tree_recursive<F: Float>(
    binned: &[Vec<u16>],
    gradients: &[F],
    hessians: &[F],
    sample_indices: &[usize],
    bin_infos: &[FeatureBinInfo<F>],
    params: &HistTreeParams<F>,
    histograms: &[FeatureHistogram<F>],
    depth: usize,
    nodes: &mut Vec<HistNode<F>>,
) -> usize {
    let n = sample_indices.len();
    let grad_sum: F = sample_indices
        .iter()
        .map(|&i| gradients[i])
        .fold(F::zero(), |a, b| a + b);
    let hess_sum: F = sample_indices
        .iter()
        .map(|&i| hessians[i])
        .fold(F::zero(), |a, b| a + b);

    // Check stopping conditions.
    let at_max_depth = params.max_depth.is_some_and(|d| depth >= d);
    let too_few = n < 2 * params.min_samples_leaf;

    if at_max_depth || too_few {
        let idx = nodes.len();
        nodes.push(HistNode::Leaf {
            value: compute_leaf_value(grad_sum, hess_sum, params.l2_regularization),
            n_samples: n,
        });
        return idx;
    }

    // Find best split.
    let split = find_best_split_from_histograms(
        histograms,
        bin_infos,
        grad_sum,
        hess_sum,
        n,
        params.l2_regularization,
        params.min_samples_leaf,
    );

    let split = if let Some(s) = split {
        s
    } else {
        let idx = nodes.len();
        nodes.push(HistNode::Leaf {
            value: compute_leaf_value(grad_sum, hess_sum, params.l2_regularization),
            n_samples: n,
        });
        return idx;
    };

    // Partition samples into left and right.
    let (left_indices, right_indices): (Vec<usize>, Vec<usize>) =
        sample_indices.iter().partition(|&&i| {
            let b = binned[i][split.feature];
            if b == NAN_BIN {
                split.nan_goes_left
            } else {
                b <= split.threshold_bin
            }
        });

    if left_indices.is_empty() || right_indices.is_empty() {
        let idx = nodes.len();
        nodes.push(HistNode::Leaf {
            value: compute_leaf_value(grad_sum, hess_sum, params.l2_regularization),
            n_samples: n,
        });
        return idx;
    }

    // Build histograms for the smaller child, then use subtraction trick.
    let (small_indices, _large_indices, small_is_left) =
        if left_indices.len() <= right_indices.len() {
            (&left_indices, &right_indices, true)
        } else {
            (&right_indices, &left_indices, false)
        };

    let small_hist = build_histograms(binned, gradients, hessians, small_indices, bin_infos);
    let large_hist = subtract_histograms(histograms, &small_hist, bin_infos);

    let (left_hist, right_hist) = if small_is_left {
        (small_hist, large_hist)
    } else {
        (large_hist, small_hist)
    };

    // Reserve a placeholder for this split node.
    let node_idx = nodes.len();
    nodes.push(HistNode::Leaf {
        value: F::zero(),
        n_samples: 0,
    }); // placeholder

    // Recurse.
    let left_idx = build_hist_tree_recursive(
        binned,
        gradients,
        hessians,
        &left_indices,
        bin_infos,
        params,
        &left_hist,
        depth + 1,
        nodes,
    );
    let right_idx = build_hist_tree_recursive(
        binned,
        gradients,
        hessians,
        &right_indices,
        bin_infos,
        params,
        &right_hist,
        depth + 1,
        nodes,
    );

    nodes[node_idx] = HistNode::Split {
        feature: split.feature,
        threshold_bin: split.threshold_bin,
        nan_goes_left: split.nan_goes_left,
        left: left_idx,
        right: right_idx,
        gain: split.gain,
        n_samples: n,
    };

    node_idx
}

/// Entry in the best-first priority queue.
struct SplitTask {
    /// Indices of samples at this node.
    sample_indices: Vec<usize>,
    /// The node index in the flat node vec (a leaf placeholder).
    node_idx: usize,
    /// Depth of this node.
    depth: usize,
    /// The gain of the best split at this node.
    gain: f64,
    /// Feature of the best split.
    feature: usize,
    /// Bin threshold of the best split.
    threshold_bin: u16,
    /// Whether NaN goes left.
    nan_goes_left: bool,
}

/// Build a histogram tree using best-first (leaf-wise) growth with max_leaf_nodes.
fn build_hist_tree_best_first<F: Float>(
    binned: &[Vec<u16>],
    gradients: &[F],
    hessians: &[F],
    sample_indices: &[usize],
    bin_infos: &[FeatureBinInfo<F>],
    params: &HistTreeParams<F>,
) -> Vec<HistNode<F>> {
    let max_leaves = params.max_leaf_nodes.unwrap_or(usize::MAX);
    let mut nodes: Vec<HistNode<F>> = Vec::new();

    // Root node.
    let n = sample_indices.len();
    let grad_sum: F = sample_indices
        .iter()
        .map(|&i| gradients[i])
        .fold(F::zero(), |a, b| a + b);
    let hess_sum: F = sample_indices
        .iter()
        .map(|&i| hessians[i])
        .fold(F::zero(), |a, b| a + b);

    let root_idx = nodes.len();
    nodes.push(HistNode::Leaf {
        value: compute_leaf_value(grad_sum, hess_sum, params.l2_regularization),
        n_samples: n,
    });

    let root_hist = build_histograms(binned, gradients, hessians, sample_indices, bin_infos);
    let root_split = find_best_split_from_histograms(
        &root_hist,
        bin_infos,
        grad_sum,
        hess_sum,
        n,
        params.l2_regularization,
        params.min_samples_leaf,
    );

    let mut pending: Vec<(SplitTask, Vec<FeatureHistogram<F>>)> = Vec::new();
    let mut n_leaves: usize = 1;

    if let Some(split) = root_split {
        let at_max_depth = params.max_depth.is_some_and(|d| d == 0);
        if !at_max_depth {
            pending.push((
                SplitTask {
                    sample_indices: sample_indices.to_vec(),
                    node_idx: root_idx,
                    depth: 0,
                    gain: split.gain.to_f64().unwrap_or(0.0),
                    feature: split.feature,
                    threshold_bin: split.threshold_bin,
                    nan_goes_left: split.nan_goes_left,
                },
                root_hist,
            ));
        }
    }

    while !pending.is_empty() && n_leaves < max_leaves {
        // Pick the task with the highest gain.
        let best_idx = pending
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| {
                a.0.gain
                    .partial_cmp(&b.0.gain)
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
            .map(|(i, _)| i)
            .unwrap();

        let (task, parent_hist) = pending.swap_remove(best_idx);

        // Partition.
        let (left_indices, right_indices): (Vec<usize>, Vec<usize>) =
            task.sample_indices.iter().partition(|&&i| {
                let b = binned[i][task.feature];
                if b == NAN_BIN {
                    task.nan_goes_left
                } else {
                    b <= task.threshold_bin
                }
            });

        if left_indices.is_empty() || right_indices.is_empty() {
            continue;
        }

        // Build histograms using subtraction trick.
        let (small_indices, _large_indices, small_is_left) =
            if left_indices.len() <= right_indices.len() {
                (&left_indices, &right_indices, true)
            } else {
                (&right_indices, &left_indices, false)
            };

        let small_hist = build_histograms(binned, gradients, hessians, small_indices, bin_infos);
        let large_hist = subtract_histograms(&parent_hist, &small_hist, bin_infos);

        let (left_hist, right_hist) = if small_is_left {
            (small_hist, large_hist)
        } else {
            (large_hist, small_hist)
        };

        // Create left and right leaf nodes.
        let left_grad: F = left_indices
            .iter()
            .map(|&i| gradients[i])
            .fold(F::zero(), |a, b| a + b);
        let left_hess: F = left_indices
            .iter()
            .map(|&i| hessians[i])
            .fold(F::zero(), |a, b| a + b);
        let right_grad: F = right_indices
            .iter()
            .map(|&i| gradients[i])
            .fold(F::zero(), |a, b| a + b);
        let right_hess: F = right_indices
            .iter()
            .map(|&i| hessians[i])
            .fold(F::zero(), |a, b| a + b);

        let left_idx = nodes.len();
        nodes.push(HistNode::Leaf {
            value: compute_leaf_value(left_grad, left_hess, params.l2_regularization),
            n_samples: left_indices.len(),
        });
        let right_idx = nodes.len();
        nodes.push(HistNode::Leaf {
            value: compute_leaf_value(right_grad, right_hess, params.l2_regularization),
            n_samples: right_indices.len(),
        });

        // Convert the parent leaf placeholder into a split node.
        nodes[task.node_idx] = HistNode::Split {
            feature: task.feature,
            threshold_bin: task.threshold_bin,
            nan_goes_left: task.nan_goes_left,
            left: left_idx,
            right: right_idx,
            gain: F::from(task.gain).unwrap(),
            n_samples: task.sample_indices.len(),
        };

        // One leaf became two, so net +1 leaf.
        n_leaves += 1;

        let child_depth = task.depth + 1;
        let at_max_depth = params.max_depth.is_some_and(|d| child_depth >= d);

        if !at_max_depth && n_leaves < max_leaves {
            // Try to split left child.
            if left_indices.len() >= 2 * params.min_samples_leaf {
                let left_split = find_best_split_from_histograms(
                    &left_hist,
                    bin_infos,
                    left_grad,
                    left_hess,
                    left_indices.len(),
                    params.l2_regularization,
                    params.min_samples_leaf,
                );
                if let Some(s) = left_split {
                    pending.push((
                        SplitTask {
                            sample_indices: left_indices,
                            node_idx: left_idx,
                            depth: child_depth,
                            gain: s.gain.to_f64().unwrap_or(0.0),
                            feature: s.feature,
                            threshold_bin: s.threshold_bin,
                            nan_goes_left: s.nan_goes_left,
                        },
                        left_hist,
                    ));
                }
            }

            // Try to split right child.
            if right_indices.len() >= 2 * params.min_samples_leaf {
                let right_split = find_best_split_from_histograms(
                    &right_hist,
                    bin_infos,
                    right_grad,
                    right_hess,
                    right_indices.len(),
                    params.l2_regularization,
                    params.min_samples_leaf,
                );
                if let Some(s) = right_split {
                    pending.push((
                        SplitTask {
                            sample_indices: right_indices,
                            node_idx: right_idx,
                            depth: child_depth,
                            gain: s.gain.to_f64().unwrap_or(0.0),
                            feature: s.feature,
                            threshold_bin: s.threshold_bin,
                            nan_goes_left: s.nan_goes_left,
                        },
                        right_hist,
                    ));
                }
            }
        }
    }

    nodes
}

/// Traverse a histogram tree to find the leaf for a single binned sample.
#[inline]
fn traverse_hist_tree<F: Float>(nodes: &[HistNode<F>], sample_bins: &[u16]) -> usize {
    let mut idx = 0;
    loop {
        match &nodes[idx] {
            HistNode::Split {
                feature,
                threshold_bin,
                nan_goes_left,
                left,
                right,
                ..
            } => {
                let b = sample_bins[*feature];
                if b == NAN_BIN {
                    idx = if *nan_goes_left { *left } else { *right };
                } else if b <= *threshold_bin {
                    idx = *left;
                } else {
                    idx = *right;
                }
            }
            HistNode::Leaf { .. } => return idx,
        }
    }
}

/// Compute feature importances from a histogram tree's gain values.
fn compute_hist_feature_importances<F: Float>(
    nodes: &[HistNode<F>],
    n_features: usize,
) -> Array1<F> {
    let mut importances = Array1::zeros(n_features);
    for node in nodes {
        if let HistNode::Split { feature, gain, .. } = node {
            importances[*feature] = importances[*feature] + *gain;
        }
    }
    importances
}

// ---------------------------------------------------------------------------
// Internal helpers
// ---------------------------------------------------------------------------

/// Sigmoid function: 1 / (1 + exp(-x)).
fn sigmoid<F: Float>(x: F) -> F {
    F::one() / (F::one() + (-x).exp())
}

/// Compute softmax probabilities for each class across all samples.
///
/// Returns `probs[k][i]` = probability of class k for sample i.
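///
/// E.g. a single sample with raw scores `f = [0.0, ln 3]` yields
/// probabilities `[0.25, 0.75]`; subtracting the row max before
/// exponentiating keeps the computation numerically stable.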
fn softmax_matrix<F: Float>(
    f_vals: &[Array1<F>],
    n_samples: usize,
    n_classes: usize,
) -> Vec<Vec<F>> {
    let mut probs: Vec<Vec<F>> = vec![vec![F::zero(); n_samples]; n_classes];
    for i in 0..n_samples {
        let max_val = (0..n_classes)
            .map(|k| f_vals[k][i])
            .fold(F::neg_infinity(), |a, b| if b > a { b } else { a });
        let mut sum = F::zero();
        let mut exps = vec![F::zero(); n_classes];
        for k in 0..n_classes {
            exps[k] = (f_vals[k][i] - max_val).exp();
            sum = sum + exps[k];
        }
        let eps = F::from(1e-15).unwrap();
        if sum < eps {
            sum = eps;
        }
        for k in 0..n_classes {
            probs[k][i] = exps[k] / sum;
        }
    }
    probs
}

// ---------------------------------------------------------------------------
// HistGradientBoostingRegressor
// ---------------------------------------------------------------------------

/// Histogram-based gradient boosting regressor.
///
/// Uses quantile-based feature binning and gradient/hessian histograms for
/// O(n_bins) split finding per node. This is significantly faster than the
/// standard [`GradientBoostingRegressor`](crate::GradientBoostingRegressor)
/// for larger datasets.
///
/// # Type Parameters
///
/// - `F`: The floating-point type (`f32` or `f64`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HistGradientBoostingRegressor<F> {
    /// Number of boosting stages (trees).
    pub n_estimators: usize,
    /// Learning rate (shrinkage) applied to each tree's contribution.
    pub learning_rate: f64,
    /// Maximum depth of each tree.
    pub max_depth: Option<usize>,
    /// Minimum number of samples required in a leaf node.
    pub min_samples_leaf: usize,
    /// Maximum number of bins for feature discretisation (at most 256).
    pub max_bins: u16,
    /// L2 regularization term on weights.
    pub l2_regularization: f64,
    /// Maximum number of leaf nodes per tree (best-first growth).
    /// If `None`, depth-first growth is used with `max_depth`.
    pub max_leaf_nodes: Option<usize>,
    /// Loss function.
    pub loss: HistRegressionLoss,
    /// Random seed for reproducibility (reserved for future subsampling).
    pub random_state: Option<u64>,
    _marker: std::marker::PhantomData<F>,
}

impl<F: Float> HistGradientBoostingRegressor<F> {
    /// Create a new `HistGradientBoostingRegressor` with default settings.
    ///
    /// Defaults: `n_estimators = 100`, `learning_rate = 0.1`,
    /// `max_depth = None`, `min_samples_leaf = 20`,
    /// `max_bins = 255`, `l2_regularization = 0.0`,
    /// `max_leaf_nodes = Some(31)`, `loss = LeastSquares`.
    #[must_use]
    pub fn new() -> Self {
        Self {
            n_estimators: 100,
            learning_rate: 0.1,
            max_depth: None,
            min_samples_leaf: 20,
            max_bins: 255,
            l2_regularization: 0.0,
            max_leaf_nodes: Some(31),
            loss: HistRegressionLoss::LeastSquares,
            random_state: None,
            _marker: std::marker::PhantomData,
        }
    }

    /// Set the number of boosting stages.
    #[must_use]
    pub fn with_n_estimators(mut self, n: usize) -> Self {
        self.n_estimators = n;
        self
    }

    /// Set the learning rate (shrinkage).
    #[must_use]
    pub fn with_learning_rate(mut self, lr: f64) -> Self {
        self.learning_rate = lr;
        self
    }

    /// Set the maximum tree depth.
    #[must_use]
    pub fn with_max_depth(mut self, d: Option<usize>) -> Self {
        self.max_depth = d;
        self
    }

    /// Set the minimum number of samples in a leaf.
    #[must_use]
    pub fn with_min_samples_leaf(mut self, n: usize) -> Self {
        self.min_samples_leaf = n;
        self
    }

    /// Set the maximum number of bins for feature discretisation.
    #[must_use]
    pub fn with_max_bins(mut self, bins: u16) -> Self {
        self.max_bins = bins;
        self
    }

    /// Set the L2 regularization term.
    #[must_use]
    pub fn with_l2_regularization(mut self, reg: f64) -> Self {
        self.l2_regularization = reg;
        self
    }

    /// Set the maximum number of leaf nodes (best-first growth).
    #[must_use]
    pub fn with_max_leaf_nodes(mut self, n: Option<usize>) -> Self {
        self.max_leaf_nodes = n;
        self
    }

    /// Set the loss function.
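    ///
    /// A small sketch (assuming `HistRegressionLoss` is re-exported at the
    /// crate root alongside the estimator, as the module-level example
    /// suggests):
    ///
    /// ```
    /// use ferrolearn_tree::{HistGradientBoostingRegressor, HistRegressionLoss};
    ///
    /// // L1 loss is more robust to outliers in the targets.
    /// let _model = HistGradientBoostingRegressor::<f64>::new()
    ///     .with_loss(HistRegressionLoss::LeastAbsoluteDeviation);
    /// ```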
    #[must_use]
    pub fn with_loss(mut self, loss: HistRegressionLoss) -> Self {
        self.loss = loss;
        self
    }

    /// Set the random seed for reproducibility.
    #[must_use]
    pub fn with_random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }
}

impl<F: Float> Default for HistGradientBoostingRegressor<F> {
    fn default() -> Self {
        Self::new()
    }
}

// ---------------------------------------------------------------------------
// FittedHistGradientBoostingRegressor
// ---------------------------------------------------------------------------

/// A fitted histogram-based gradient boosting regressor.
///
/// Stores the binning information, initial prediction, and the sequence of
/// fitted histogram trees. Predictions are computed by binning the input
/// features and traversing each tree.
#[derive(Debug, Clone)]
pub struct FittedHistGradientBoostingRegressor<F> {
    /// Bin edge information for each feature.
    bin_infos: Vec<FeatureBinInfo<F>>,
    /// Initial prediction (baseline).
    init: F,
    /// Learning rate used during training.
    learning_rate: F,
    /// Sequence of fitted histogram trees.
    trees: Vec<Vec<HistNode<F>>>,
    /// Number of features.
    n_features: usize,
    /// Per-feature importance scores (normalised).
    feature_importances: Array1<F>,
}

impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<F>>
    for HistGradientBoostingRegressor<F>
{
    type Fitted = FittedHistGradientBoostingRegressor<F>;
    type Error = FerroError;

    /// Fit the histogram-based gradient boosting regressor.
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::ShapeMismatch`] if `x` and `y` have different
    /// numbers of samples.
    /// Returns [`FerroError::InsufficientSamples`] if there are no samples.
    /// Returns [`FerroError::InvalidParameter`] for invalid hyperparameters.
    fn fit(
        &self,
        x: &Array2<F>,
        y: &Array1<F>,
    ) -> Result<FittedHistGradientBoostingRegressor<F>, FerroError> {
        let (n_samples, n_features) = x.dim();

        // Validate inputs.
        if n_samples != y.len() {
            return Err(FerroError::ShapeMismatch {
                expected: vec![n_samples],
                actual: vec![y.len()],
                context: "y length must match number of samples in X".into(),
            });
        }
        if n_samples == 0 {
            return Err(FerroError::InsufficientSamples {
                required: 1,
                actual: 0,
                context: "HistGradientBoostingRegressor requires at least one sample".into(),
            });
        }
        if self.n_estimators == 0 {
            return Err(FerroError::InvalidParameter {
                name: "n_estimators".into(),
                reason: "must be at least 1".into(),
            });
        }
        if self.learning_rate <= 0.0 {
            return Err(FerroError::InvalidParameter {
                name: "learning_rate".into(),
                reason: "must be positive".into(),
            });
        }
        if self.max_bins < 2 {
            return Err(FerroError::InvalidParameter {
                name: "max_bins".into(),
                reason: "must be at least 2".into(),
            });
        }

        let lr = F::from(self.learning_rate).unwrap();
        let l2_reg = F::from(self.l2_regularization).unwrap();

        // Compute bin edges and bin the data.
        let bin_infos = compute_bin_edges(x, self.max_bins);
        let binned = bin_data(x, &bin_infos);

        // Initial prediction.
        let init = match self.loss {
            HistRegressionLoss::LeastSquares => {
                let sum: F = y.iter().copied().fold(F::zero(), |a, b| a + b);
                sum / F::from(n_samples).unwrap()
            }
            HistRegressionLoss::LeastAbsoluteDeviation => median_f(y),
        };

        let mut f_vals = Array1::from_elem(n_samples, init);
        let all_indices: Vec<usize> = (0..n_samples).collect();

        let tree_params = HistTreeParams {
            max_depth: self.max_depth,
            min_samples_leaf: self.min_samples_leaf,
            max_leaf_nodes: self.max_leaf_nodes,
            l2_regularization: l2_reg,
        };

        let mut trees = Vec::with_capacity(self.n_estimators);

        for _ in 0..self.n_estimators {
            // Compute gradients and hessians.
            let (gradients, hessians) = match self.loss {
                HistRegressionLoss::LeastSquares => {
                    let grads: Vec<F> = (0..n_samples).map(|i| -(y[i] - f_vals[i])).collect();
                    let hess: Vec<F> = vec![F::one(); n_samples];
                    (grads, hess)
                }
                HistRegressionLoss::LeastAbsoluteDeviation => {
                    let grads: Vec<F> = (0..n_samples)
                        .map(|i| {
                            let diff = y[i] - f_vals[i];
                            if diff > F::zero() {
                                -F::one()
                            } else if diff < F::zero() {
                                F::one()
                            } else {
                                F::zero()
                            }
                        })
                        .collect();
                    let hess: Vec<F> = vec![F::one(); n_samples];
                    (grads, hess)
                }
            };

            let tree = build_hist_tree(
                &binned,
                &gradients,
                &hessians,
                &all_indices,
                &bin_infos,
                &tree_params,
            );

            // Update predictions.
            for i in 0..n_samples {
                let leaf_idx = traverse_hist_tree(&tree, &binned[i]);
                if let HistNode::Leaf { value, .. } = tree[leaf_idx] {
                    f_vals[i] = f_vals[i] + lr * value;
                }
            }

            trees.push(tree);
        }

        // Compute feature importances.
        let mut total_importances = Array1::<F>::zeros(n_features);
        for tree_nodes in &trees {
            total_importances =
                total_importances + compute_hist_feature_importances(tree_nodes, n_features);
        }
        let imp_sum: F = total_importances
            .iter()
            .copied()
            .fold(F::zero(), |a, b| a + b);
        if imp_sum > F::zero() {
            total_importances.mapv_inplace(|v| v / imp_sum);
        }

        Ok(FittedHistGradientBoostingRegressor {
            bin_infos,
            init,
            learning_rate: lr,
            trees,
            n_features,
            feature_importances: total_importances,
        })
    }
}

impl<F: Float + Send + Sync + 'static> Predict<Array2<F>>
    for FittedHistGradientBoostingRegressor<F>
{
    type Output = Array1<F>;
    type Error = FerroError;

    /// Predict target values.
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::ShapeMismatch`] if the number of features does
    /// not match the fitted model.
    fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
        if x.ncols() != self.n_features {
            return Err(FerroError::ShapeMismatch {
                expected: vec![self.n_features],
                actual: vec![x.ncols()],
                context: "number of features must match fitted model".into(),
            });
        }

        let n_samples = x.nrows();
        let binned = bin_data(x, &self.bin_infos);
        let mut predictions = Array1::from_elem(n_samples, self.init);

        for i in 0..n_samples {
            for tree_nodes in &self.trees {
                let leaf_idx = traverse_hist_tree(tree_nodes, &binned[i]);
                if let HistNode::Leaf { value, .. } = tree_nodes[leaf_idx] {
                    predictions[i] = predictions[i] + self.learning_rate * value;
                }
            }
        }

        Ok(predictions)
    }
}

impl<F: Float + Send + Sync + 'static> HasFeatureImportances<F>
    for FittedHistGradientBoostingRegressor<F>
{
    fn feature_importances(&self) -> &Array1<F> {
        &self.feature_importances
    }
}

// Pipeline integration.
impl<F: Float + Send + Sync + 'static> PipelineEstimator<F> for HistGradientBoostingRegressor<F> {
    fn fit_pipeline(
        &self,
        x: &Array2<F>,
        y: &Array1<F>,
    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
        let fitted = self.fit(x, y)?;
        Ok(Box::new(fitted))
    }
}

impl<F: Float + Send + Sync + 'static> FittedPipelineEstimator<F>
    for FittedHistGradientBoostingRegressor<F>
{
    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
        self.predict(x)
    }
}

// ---------------------------------------------------------------------------
// HistGradientBoostingClassifier
// ---------------------------------------------------------------------------

/// Histogram-based gradient boosting classifier.
///
/// For binary classification a single model is trained on log-odds residuals.
/// For multiclass (*K* classes), *K* histogram trees are built per boosting
/// round (one per class, with gradients taken from the softmax log-loss).
///
/// # Type Parameters
///
/// - `F`: The floating-point type (`f32` or `f64`).
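///
/// # Examples
///
/// A minimal fitting sketch (mirroring the module-level regressor example
/// and assuming the same crate-root re-export):
///
/// ```
/// use ferrolearn_tree::HistGradientBoostingClassifier;
/// use ferrolearn_core::Fit;
/// use ndarray::{array, Array1, Array2};
///
/// let x = Array2::from_shape_vec((8, 1), vec![
///     1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
/// ]).unwrap();
/// let y: Array1<usize> = array![0, 0, 0, 0, 1, 1, 1, 1];
///
/// let model = HistGradientBoostingClassifier::<f64>::new()
///     .with_n_estimators(20)
///     .with_learning_rate(0.1);
/// let _fitted = model.fit(&x, &y).unwrap();
/// ```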
1385#[derive(Debug, Clone, Serialize, Deserialize)]
1386pub struct HistGradientBoostingClassifier<F> {
1387    /// Number of boosting stages.
1388    pub n_estimators: usize,
1389    /// Learning rate (shrinkage).
1390    pub learning_rate: f64,
1391    /// Maximum depth of each tree.
1392    pub max_depth: Option<usize>,
1393    /// Minimum number of samples required in a leaf node.
1394    pub min_samples_leaf: usize,
1395    /// Maximum number of bins for feature discretisation (at most 256).
1396    pub max_bins: u16,
1397    /// L2 regularization term on weights.
1398    pub l2_regularization: f64,
1399    /// Maximum number of leaf nodes per tree (best-first growth).
1400    pub max_leaf_nodes: Option<usize>,
1401    /// Classification loss function.
1402    pub loss: HistClassificationLoss,
1403    /// Random seed for reproducibility (reserved for future subsampling).
1404    pub random_state: Option<u64>,
1405    _marker: std::marker::PhantomData<F>,
1406}
1407
1408impl<F: Float> HistGradientBoostingClassifier<F> {
1409    /// Create a new `HistGradientBoostingClassifier` with default settings.
1410    ///
1411    /// Defaults: `n_estimators = 100`, `learning_rate = 0.1`,
1412    /// `max_depth = None`, `min_samples_leaf = 20`,
1413    /// `max_bins = 255`, `l2_regularization = 0.0`,
1414    /// `max_leaf_nodes = Some(31)`, `loss = LogLoss`.
1415    #[must_use]
1416    pub fn new() -> Self {
1417        Self {
1418            n_estimators: 100,
1419            learning_rate: 0.1,
1420            max_depth: None,
1421            min_samples_leaf: 20,
1422            max_bins: 255,
1423            l2_regularization: 0.0,
1424            max_leaf_nodes: Some(31),
1425            loss: HistClassificationLoss::LogLoss,
1426            random_state: None,
1427            _marker: std::marker::PhantomData,
1428        }
1429    }
1430
1431    /// Set the number of boosting stages.
1432    #[must_use]
1433    pub fn with_n_estimators(mut self, n: usize) -> Self {
1434        self.n_estimators = n;
1435        self
1436    }
1437
1438    /// Set the learning rate (shrinkage).
1439    #[must_use]
1440    pub fn with_learning_rate(mut self, lr: f64) -> Self {
1441        self.learning_rate = lr;
1442        self
1443    }
1444
1445    /// Set the maximum tree depth.
1446    #[must_use]
1447    pub fn with_max_depth(mut self, d: Option<usize>) -> Self {
1448        self.max_depth = d;
1449        self
1450    }
1451
1452    /// Set the minimum number of samples in a leaf.
1453    #[must_use]
1454    pub fn with_min_samples_leaf(mut self, n: usize) -> Self {
1455        self.min_samples_leaf = n;
1456        self
1457    }
1458
1459    /// Set the maximum number of bins for feature discretisation.
1460    #[must_use]
1461    pub fn with_max_bins(mut self, bins: u16) -> Self {
1462        self.max_bins = bins;
1463        self
1464    }
1465
1466    /// Set the L2 regularization term.
1467    #[must_use]
1468    pub fn with_l2_regularization(mut self, reg: f64) -> Self {
1469        self.l2_regularization = reg;
1470        self
1471    }
1472
1473    /// Set the maximum number of leaf nodes (best-first growth).
1474    #[must_use]
1475    pub fn with_max_leaf_nodes(mut self, n: Option<usize>) -> Self {
1476        self.max_leaf_nodes = n;
1477        self
1478    }
1479
1480    /// Set the random seed for reproducibility.
1481    #[must_use]
1482    pub fn with_random_state(mut self, seed: u64) -> Self {
1483        self.random_state = Some(seed);
1484        self
1485    }
1486}
1487
1488impl<F: Float> Default for HistGradientBoostingClassifier<F> {
1489    fn default() -> Self {
1490        Self::new()
1491    }
1492}
1493
1494// ---------------------------------------------------------------------------
1495// FittedHistGradientBoostingClassifier
1496// ---------------------------------------------------------------------------
1497
1498/// A fitted histogram-based gradient boosting classifier.
1499///
1500/// For binary classification, stores a single sequence of trees predicting log-odds.
1501/// For multiclass, stores `K` sequences of trees (one per class).
1502#[derive(Debug, Clone)]
1503pub struct FittedHistGradientBoostingClassifier<F> {
1504    /// Bin edge information for each feature.
1505    bin_infos: Vec<FeatureBinInfo<F>>,
1506    /// Sorted unique class labels.
1507    classes: Vec<usize>,
1508    /// Initial predictions per class (log-odds or log-prior).
1509    init: Vec<F>,
1510    /// Learning rate.
1511    learning_rate: F,
1512    /// Trees: for binary, `trees[0]` has all trees. For multiclass,
1513    /// `trees[k]` has trees for class k.
1514    trees: Vec<Vec<Vec<HistNode<F>>>>,
1515    /// Number of features.
1516    n_features: usize,
1517    /// Per-feature importance scores (normalised).
1518    feature_importances: Array1<F>,
1519}
1520
1521impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<usize>>
1522    for HistGradientBoostingClassifier<F>
1523{
1524    type Fitted = FittedHistGradientBoostingClassifier<F>;
1525    type Error = FerroError;
1526
1527    /// Fit the histogram-based gradient boosting classifier.
1528    ///
1529    /// # Errors
1530    ///
1531    /// Returns [`FerroError::ShapeMismatch`] if `x` and `y` have different
1532    /// numbers of samples.
1533    /// Returns [`FerroError::InsufficientSamples`] if there are no samples.
1534    /// Returns [`FerroError::InvalidParameter`] for invalid hyperparameters.
1535    fn fit(
1536        &self,
1537        x: &Array2<F>,
1538        y: &Array1<usize>,
1539    ) -> Result<FittedHistGradientBoostingClassifier<F>, FerroError> {
1540        let (n_samples, n_features) = x.dim();
1541
1542        if n_samples != y.len() {
1543            return Err(FerroError::ShapeMismatch {
1544                expected: vec![n_samples],
1545                actual: vec![y.len()],
1546                context: "y length must match number of samples in X".into(),
1547            });
1548        }
1549        if n_samples == 0 {
1550            return Err(FerroError::InsufficientSamples {
1551                required: 1,
1552                actual: 0,
1553                context: "HistGradientBoostingClassifier requires at least one sample".into(),
1554            });
1555        }
1556        if self.n_estimators == 0 {
1557            return Err(FerroError::InvalidParameter {
1558                name: "n_estimators".into(),
1559                reason: "must be at least 1".into(),
1560            });
1561        }
1562        if self.learning_rate <= 0.0 {
1563            return Err(FerroError::InvalidParameter {
1564                name: "learning_rate".into(),
1565                reason: "must be positive".into(),
1566            });
1567        }
1568        if self.max_bins < 2 {
1569            return Err(FerroError::InvalidParameter {
1570                name: "max_bins".into(),
1571                reason: "must be at least 2".into(),
1572            });
1573        }
1574
1575        // Determine unique classes.
1576        let mut classes: Vec<usize> = y.iter().copied().collect();
1577        classes.sort_unstable();
1578        classes.dedup();
1579        let n_classes = classes.len();
1580
1581        if n_classes < 2 {
1582            return Err(FerroError::InvalidParameter {
1583                name: "y".into(),
1584                reason: "need at least 2 distinct classes".into(),
1585            });
1586        }
1587
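        // Map original labels to contiguous class indices 0..n_classes.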
1588        let y_mapped: Vec<usize> = y
1589            .iter()
1590            .map(|&c| classes.iter().position(|&cl| cl == c).unwrap())
1591            .collect();
1592
1593        let lr = F::from(self.learning_rate).unwrap();
1594        let l2_reg = F::from(self.l2_regularization).unwrap();
1595
1596        // Bin the data.
1597        let bin_infos = compute_bin_edges(x, self.max_bins);
1598        let binned = bin_data(x, &bin_infos);
1599
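        // Growth constraints shared by every tree in the ensemble.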
1600        let tree_params = HistTreeParams {
1601            max_depth: self.max_depth,
1602            min_samples_leaf: self.min_samples_leaf,
1603            max_leaf_nodes: self.max_leaf_nodes,
1604            l2_regularization: l2_reg,
1605        };
1606
1607        let all_indices: Vec<usize> = (0..n_samples).collect();
1608
1609        if n_classes == 2 {
1610            self.fit_binary(
1611                &binned,
1612                &y_mapped,
1613                n_samples,
1614                n_features,
1615                &classes,
1616                lr,
1617                &tree_params,
1618                &all_indices,
1619                &bin_infos,
1620            )
1621        } else {
1622            self.fit_multiclass(
1623                &binned,
1624                &y_mapped,
1625                n_samples,
1626                n_features,
1627                n_classes,
1628                &classes,
1629                lr,
1630                &tree_params,
1631                &all_indices,
1632                &bin_infos,
1633            )
1634        }
1635    }
1636}
1637
1638impl<F: Float + Send + Sync + 'static> HistGradientBoostingClassifier<F> {
1639    /// Fit binary classification (log-loss on log-odds).
1640    #[allow(clippy::too_many_arguments)]
1641    fn fit_binary(
1642        &self,
1643        binned: &[Vec<u16>],
1644        y_mapped: &[usize],
1645        n_samples: usize,
1646        n_features: usize,
1647        classes: &[usize],
1648        lr: F,
1649        tree_params: &HistTreeParams<F>,
1650        all_indices: &[usize],
1651        bin_infos: &[FeatureBinInfo<F>],
1652    ) -> Result<FittedHistGradientBoostingClassifier<F>, FerroError> {
1653        // Initial log-odds.
1654        let pos_count = y_mapped.iter().filter(|&&c| c == 1).count();
1655        let p = F::from(pos_count).unwrap() / F::from(n_samples).unwrap();
1656        let eps = F::from(1e-15).unwrap();
1657        let p_clipped = p.max(eps).min(F::one() - eps);
1658        let init_val = (p_clipped / (F::one() - p_clipped)).ln();
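        // e.g. 6 positives out of 8 samples: p = 0.75, so
        // init = ln(0.75 / 0.25) = ln 3 ≈ 1.0986.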
1659
1660        let mut f_vals = Array1::from_elem(n_samples, init_val);
1661        let mut trees_seq: Vec<Vec<HistNode<F>>> = Vec::with_capacity(self.n_estimators);
1662
1663        for _ in 0..self.n_estimators {
1664            // Compute probabilities.
1665            let probs: Vec<F> = f_vals.iter().map(|&fv| sigmoid(fv)).collect();
1666
            // Gradients and hessians for log-loss:
            //   gradient = p - y  (leaf values fit the negative gradient,
            //   i.e. the pseudo-residual y - p)
            //   hessian  = p * (1 - p)
1670            let gradients: Vec<F> = (0..n_samples)
1671                .map(|i| {
1672                    let yi = F::from(y_mapped[i]).unwrap();
1673                    probs[i] - yi
1674                })
1675                .collect();
1676            let hessians: Vec<F> = (0..n_samples)
1677                .map(|i| {
1678                    let pi = probs[i].max(eps).min(F::one() - eps);
1679                    pi * (F::one() - pi)
1680                })
1681                .collect();
1682
1683            let tree = build_hist_tree(
1684                binned,
1685                &gradients,
1686                &hessians,
1687                all_indices,
1688                bin_infos,
1689                tree_params,
1690            );
1691
1692            // Update f_vals.
1693            for i in 0..n_samples {
1694                let leaf_idx = traverse_hist_tree(&tree, &binned[i]);
1695                if let HistNode::Leaf { value, .. } = tree[leaf_idx] {
1696                    f_vals[i] = f_vals[i] + lr * value;
1697                }
1698            }
1699
1700            trees_seq.push(tree);
1701        }
1702
1703        // Feature importances.
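        // Per-tree scores are summed over all rounds, then normalised below
        // so the importances sum to 1 (unless every score is zero).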
1704        let mut total_importances = Array1::<F>::zeros(n_features);
1705        for tree_nodes in &trees_seq {
1706            total_importances =
1707                total_importances + compute_hist_feature_importances(tree_nodes, n_features);
1708        }
1709        let imp_sum: F = total_importances
1710            .iter()
1711            .copied()
1712            .fold(F::zero(), |a, b| a + b);
1713        if imp_sum > F::zero() {
1714            total_importances.mapv_inplace(|v| v / imp_sum);
1715        }
1716
1717        Ok(FittedHistGradientBoostingClassifier {
1718            bin_infos: bin_infos.to_vec(),
1719            classes: classes.to_vec(),
1720            init: vec![init_val],
1721            learning_rate: lr,
1722            trees: vec![trees_seq],
1723            n_features,
1724            feature_importances: total_importances,
1725        })
1726    }
1727
1728    /// Fit multiclass classification (K trees per round, softmax).
1729    #[allow(clippy::too_many_arguments)]
1730    fn fit_multiclass(
1731        &self,
1732        binned: &[Vec<u16>],
1733        y_mapped: &[usize],
1734        n_samples: usize,
1735        n_features: usize,
1736        n_classes: usize,
1737        classes: &[usize],
1738        lr: F,
1739        tree_params: &HistTreeParams<F>,
1740        all_indices: &[usize],
1741        bin_infos: &[FeatureBinInfo<F>],
1742    ) -> Result<FittedHistGradientBoostingClassifier<F>, FerroError> {
1743        // Initial log-prior for each class.
1744        let mut class_counts = vec![0usize; n_classes];
1745        for &c in y_mapped {
1746            class_counts[c] += 1;
1747        }
1748        let n_f = F::from(n_samples).unwrap();
1749        let eps = F::from(1e-15).unwrap();
1750        let init_vals: Vec<F> = class_counts
1751            .iter()
1752            .map(|&cnt| {
1753                let p = (F::from(cnt).unwrap() / n_f).max(eps);
1754                p.ln()
1755            })
1756            .collect();
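        // e.g. three balanced classes: p = 1/3 each, so every class
        // starts from the same log-prior ln(1/3) ≈ -1.0986.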
1757
1758        let mut f_vals: Vec<Array1<F>> = init_vals
1759            .iter()
1760            .map(|&init| Array1::from_elem(n_samples, init))
1761            .collect();
1762
1763        let mut trees_per_class: Vec<Vec<Vec<HistNode<F>>>> = (0..n_classes)
1764            .map(|_| Vec::with_capacity(self.n_estimators))
1765            .collect();
1766
1767        for _ in 0..self.n_estimators {
1768            let probs = softmax_matrix(&f_vals, n_samples, n_classes);
1769
1770            for k in 0..n_classes {
                // Gradients and hessians for softmax cross-entropy:
                //   gradient_k = p_k - y_k
                //   hessian_k  = p_k * (1 - p_k)  (diagonal of the softmax Hessian)
1774                let gradients: Vec<F> = (0..n_samples)
1775                    .map(|i| {
1776                        let yi_k = if y_mapped[i] == k {
1777                            F::one()
1778                        } else {
1779                            F::zero()
1780                        };
1781                        probs[k][i] - yi_k
1782                    })
1783                    .collect();
1784                let hessians: Vec<F> = (0..n_samples)
1785                    .map(|i| {
1786                        let pk = probs[k][i].max(eps).min(F::one() - eps);
1787                        pk * (F::one() - pk)
1788                    })
1789                    .collect();
1790
1791                let tree = build_hist_tree(
1792                    binned,
1793                    &gradients,
1794                    &hessians,
1795                    all_indices,
1796                    bin_infos,
1797                    tree_params,
1798                );
1799
1800                // Update f_vals for class k.
1801                for (i, fv) in f_vals[k].iter_mut().enumerate() {
1802                    let leaf_idx = traverse_hist_tree(&tree, &binned[i]);
1803                    if let HistNode::Leaf { value, .. } = tree[leaf_idx] {
1804                        *fv = *fv + lr * value;
1805                    }
1806                }
1807
1808                trees_per_class[k].push(tree);
1809            }
1810        }
1811
1812        // Feature importances aggregated across all classes and rounds.
1813        let mut total_importances = Array1::<F>::zeros(n_features);
1814        for class_trees in &trees_per_class {
1815            for tree_nodes in class_trees {
1816                total_importances =
1817                    total_importances + compute_hist_feature_importances(tree_nodes, n_features);
1818            }
1819        }
1820        let imp_sum: F = total_importances
1821            .iter()
1822            .copied()
1823            .fold(F::zero(), |a, b| a + b);
1824        if imp_sum > F::zero() {
1825            total_importances.mapv_inplace(|v| v / imp_sum);
1826        }
1827
1828        Ok(FittedHistGradientBoostingClassifier {
1829            bin_infos: bin_infos.to_vec(),
1830            classes: classes.to_vec(),
1831            init: init_vals,
1832            learning_rate: lr,
1833            trees: trees_per_class,
1834            n_features,
1835            feature_importances: total_importances,
1836        })
1837    }
1838}
1839
1840impl<F: Float + Send + Sync + 'static> Predict<Array2<F>>
1841    for FittedHistGradientBoostingClassifier<F>
1842{
1843    type Output = Array1<usize>;
1844    type Error = FerroError;
1845
1846    /// Predict class labels.
1847    ///
1848    /// # Errors
1849    ///
1850    /// Returns [`FerroError::ShapeMismatch`] if the number of features does
1851    /// not match the fitted model.
1852    fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
1853        if x.ncols() != self.n_features {
1854            return Err(FerroError::ShapeMismatch {
1855                expected: vec![self.n_features],
1856                actual: vec![x.ncols()],
1857                context: "number of features must match fitted model".into(),
1858            });
1859        }
1860
1861        let n_samples = x.nrows();
1862        let n_classes = self.classes.len();
1863        let binned = bin_data(x, &self.bin_infos);
1864
1865        if n_classes == 2 {
1866            let init = self.init[0];
1867            let mut predictions = Array1::zeros(n_samples);
1868            for i in 0..n_samples {
1869                let mut f_val = init;
1870                for tree_nodes in &self.trees[0] {
1871                    let leaf_idx = traverse_hist_tree(tree_nodes, &binned[i]);
1872                    if let HistNode::Leaf { value, .. } = tree_nodes[leaf_idx] {
1873                        f_val = f_val + self.learning_rate * value;
1874                    }
1875                }
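                // sigmoid(f) >= 0.5 exactly when f >= 0, so thresholding the
                // probability is equivalent to taking the sign of the log-odds.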
1876                let prob = sigmoid(f_val);
1877                let class_idx = if prob >= F::from(0.5).unwrap() { 1 } else { 0 };
1878                predictions[i] = self.classes[class_idx];
1879            }
1880            Ok(predictions)
1881        } else {
1882            let mut predictions = Array1::zeros(n_samples);
1883            for i in 0..n_samples {
1884                let mut scores = Vec::with_capacity(n_classes);
1885                for k in 0..n_classes {
1886                    let mut f_val = self.init[k];
1887                    for tree_nodes in &self.trees[k] {
1888                        let leaf_idx = traverse_hist_tree(tree_nodes, &binned[i]);
1889                        if let HistNode::Leaf { value, .. } = tree_nodes[leaf_idx] {
1890                            f_val = f_val + self.learning_rate * value;
1891                        }
1892                    }
1893                    scores.push(f_val);
1894                }
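                // Argmax over raw scores: softmax is monotonic, so the
                // probabilities themselves are not needed for the label.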
1895                let best_k = scores
1896                    .iter()
1897                    .enumerate()
1898                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
1899                    .map_or(0, |(k, _)| k);
1900                predictions[i] = self.classes[best_k];
1901            }
1902            Ok(predictions)
1903        }
1904    }
1905}
1906
1907impl<F: Float + Send + Sync + 'static> HasFeatureImportances<F>
1908    for FittedHistGradientBoostingClassifier<F>
1909{
1910    fn feature_importances(&self) -> &Array1<F> {
1911        &self.feature_importances
1912    }
1913}
1914
1915impl<F: Float + Send + Sync + 'static> HasClasses for FittedHistGradientBoostingClassifier<F> {
1916    fn classes(&self) -> &[usize] {
1917        &self.classes
1918    }
1919
1920    fn n_classes(&self) -> usize {
1921        self.classes.len()
1922    }
1923}
1924
1925// Pipeline integration.
1926impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> PipelineEstimator<F>
1927    for HistGradientBoostingClassifier<F>
1928{
1929    fn fit_pipeline(
1930        &self,
1931        x: &Array2<F>,
1932        y: &Array1<F>,
1933    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
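        // Pipeline targets arrive as floats; coerce them to class labels,
        // mapping any non-convertible value to class 0.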
1934        let y_usize: Array1<usize> = y.mapv(|v| v.to_usize().unwrap_or(0));
1935        let fitted = self.fit(x, &y_usize)?;
1936        Ok(Box::new(FittedHgbcPipelineAdapter(fitted)))
1937    }
1938}
1939
1940/// Pipeline adapter for `FittedHistGradientBoostingClassifier<F>`.
1941struct FittedHgbcPipelineAdapter<F: Float + Send + Sync + 'static>(
1942    FittedHistGradientBoostingClassifier<F>,
1943);
1944
1945impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> FittedPipelineEstimator<F>
1946    for FittedHgbcPipelineAdapter<F>
1947{
1948    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
1949        let preds = self.0.predict(x)?;
1950        Ok(preds.mapv(|v| F::from_usize(v).unwrap_or_else(F::nan)))
1951    }
1952}
1953
1954// ---------------------------------------------------------------------------
1955// Internal helpers
1956// ---------------------------------------------------------------------------
1957
1958/// Compute the median of an `Array1`.
fn median_f<F: Float>(arr: &Array1<F>) -> F {
    if arr.is_empty() {
        return F::zero();
    }
    let mut sorted: Vec<F> = arr.iter().copied().collect();
    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
    let n = sorted.len();
    if n % 2 == 1 {
        sorted[n / 2]
    } else {
        (sorted[n / 2 - 1] + sorted[n / 2]) / F::from(2.0).unwrap()
    }
}
1972
1973// ---------------------------------------------------------------------------
1974// Tests
1975// ---------------------------------------------------------------------------
1976
1977#[cfg(test)]
1978mod tests {
1979    use super::*;
1980    use approx::assert_relative_eq;
1981    use ndarray::array;
1982
1983    // -- Binning tests --
1984
1985    #[test]
1986    fn test_bin_edges_simple() {
1987        let x =
1988            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
1989        let infos = compute_bin_edges(&x, 4);
1990        assert_eq!(infos.len(), 1);
        // With max_bins = 4, the feature should get between 2 and 4 bins.
1992        assert!(infos[0].n_bins <= 4);
1993        assert!(infos[0].n_bins >= 2);
1994        assert!(!infos[0].has_nan);
1995    }
1996
1997    #[test]
1998    fn test_bin_edges_with_nan() {
1999        let x = Array2::from_shape_vec((5, 1), vec![1.0, f64::NAN, 3.0, f64::NAN, 5.0]).unwrap();
2000        let infos = compute_bin_edges(&x, 4);
2001        assert_eq!(infos.len(), 1);
2002        assert!(infos[0].has_nan);
2003        assert!(infos[0].n_bins >= 1);
2004    }
2005
2006    #[test]
2007    fn test_bin_edges_all_nan() {
2008        let x = Array2::from_shape_vec((3, 1), vec![f64::NAN, f64::NAN, f64::NAN]).unwrap();
2009        let infos = compute_bin_edges(&x, 4);
2010        assert_eq!(infos[0].n_bins, 0);
2011        assert!(infos[0].has_nan);
2012    }
2013
2014    #[test]
2015    fn test_map_to_bin_basic() {
2016        let info = FeatureBinInfo {
2017            edges: vec![2.0, 4.0, 6.0, 8.0],
2018            n_bins: 4,
2019            has_nan: false,
2020        };
2021        assert_eq!(map_to_bin(1.0, &info), 0);
2022        assert_eq!(map_to_bin(2.0, &info), 0);
2023        assert_eq!(map_to_bin(3.0, &info), 1);
2024        assert_eq!(map_to_bin(5.0, &info), 2);
2025        assert_eq!(map_to_bin(7.0, &info), 3);
2026        assert_eq!(map_to_bin(9.0, &info), 3);
2027        assert_eq!(map_to_bin(f64::NAN, &info), NAN_BIN);
2028    }
2029
2030    #[test]
2031    fn test_bin_data_roundtrip() {
2032        let x = Array2::from_shape_vec((4, 2), vec![1.0, 10.0, 2.0, 20.0, 3.0, 30.0, 4.0, 40.0])
2033            .unwrap();
2034        let infos = compute_bin_edges(&x, 255);
2035        let binned = bin_data(&x, &infos);
2036        assert_eq!(binned.len(), 4);
2037        assert_eq!(binned[0].len(), 2);
        // Bin indices should be non-decreasing row to row, since each
        // feature column is sorted ascending.
2039        for window in binned.windows(2) {
2040            let (prev, curr) = (&window[0], &window[1]);
2041            for (&p, &c) in prev.iter().zip(curr.iter()) {
2042                assert!(c >= p);
2043            }
2044        }
2045    }
2046
2047    // -- Subtraction trick test --
2048
2049    #[test]
2050    fn test_subtraction_trick() {
2051        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
2052        let bin_infos = compute_bin_edges(&x, 255);
2053        let binned = bin_data(&x, &bin_infos);
2054
2055        let gradients = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6];
2056        let hessians = vec![1.0; 6];
2057        let all_indices: Vec<usize> = (0..6).collect();
2058        let left_indices: Vec<usize> = vec![0, 1, 2];
2059        let right_indices: Vec<usize> = vec![3, 4, 5];
2060
2061        let parent_hist =
2062            build_histograms(&binned, &gradients, &hessians, &all_indices, &bin_infos);
2063        let left_hist = build_histograms(&binned, &gradients, &hessians, &left_indices, &bin_infos);
2064        let right_from_sub = subtract_histograms(&parent_hist, &left_hist, &bin_infos);
2065        let right_direct =
2066            build_histograms(&binned, &gradients, &hessians, &right_indices, &bin_infos);
2067
2068        // The subtraction trick result should match direct computation.
2069        for j in 0..bin_infos.len() {
2070            let n_bins = bin_infos[j].n_bins as usize;
2071            for b in 0..n_bins {
2072                assert_relative_eq!(
2073                    right_from_sub[j].bins[b].grad_sum,
2074                    right_direct[j].bins[b].grad_sum,
2075                    epsilon = 1e-10
2076                );
2077                assert_relative_eq!(
2078                    right_from_sub[j].bins[b].hess_sum,
2079                    right_direct[j].bins[b].hess_sum,
2080                    epsilon = 1e-10
2081                );
2082                assert_eq!(
2083                    right_from_sub[j].bins[b].count,
2084                    right_direct[j].bins[b].count
2085                );
2086            }
2087        }
2088    }
2089
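    // -- Helper tests --

    // A minimal sanity check for the private `median_f` helper: odd and even
    // lengths, plus the zero fallback for empty input.
    #[test]
    fn test_median_f_basic() {
        let odd = array![3.0, 1.0, 2.0];
        assert_relative_eq!(median_f(&odd), 2.0, epsilon = 1e-12);

        let even = array![4.0, 1.0, 3.0, 2.0];
        assert_relative_eq!(median_f(&even), 2.5, epsilon = 1e-12);

        // Empty input falls back to zero by convention.
        let empty = Array1::<f64>::zeros(0);
        assert_relative_eq!(median_f(&empty), 0.0, epsilon = 1e-12);
    }
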
2090    // -- Regressor tests --
2091
2092    #[test]
2093    fn test_hgbr_simple_least_squares() {
2094        let x =
2095            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
2096        let y = array![1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0];
2097
2098        let model = HistGradientBoostingRegressor::<f64>::new()
2099            .with_n_estimators(50)
2100            .with_learning_rate(0.1)
2101            .with_min_samples_leaf(1)
2102            .with_max_leaf_nodes(None)
2103            .with_max_depth(Some(3))
2104            .with_random_state(42);
2105        let fitted = model.fit(&x, &y).unwrap();
2106        let preds = fitted.predict(&x).unwrap();
2107
2108        assert_eq!(preds.len(), 8);
2109        for i in 0..4 {
2110            assert!(preds[i] < 3.0, "Expected ~1.0, got {}", preds[i]);
2111        }
2112        for i in 4..8 {
2113            assert!(preds[i] > 3.0, "Expected ~5.0, got {}", preds[i]);
2114        }
2115    }
2116
2117    #[test]
2118    fn test_hgbr_lad_loss() {
2119        let x =
2120            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
2121        let y = array![1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0];
2122
2123        let model = HistGradientBoostingRegressor::<f64>::new()
2124            .with_n_estimators(50)
2125            .with_loss(HistRegressionLoss::LeastAbsoluteDeviation)
2126            .with_min_samples_leaf(1)
2127            .with_max_leaf_nodes(None)
2128            .with_max_depth(Some(3))
2129            .with_random_state(42);
2130        let fitted = model.fit(&x, &y).unwrap();
2131        let preds = fitted.predict(&x).unwrap();
2132
2133        assert_eq!(preds.len(), 8);
2134        for i in 0..4 {
2135            assert!(preds[i] < 3.5, "LAD expected <3.5, got {}", preds[i]);
2136        }
2137        for i in 4..8 {
2138            assert!(preds[i] > 2.5, "LAD expected >2.5, got {}", preds[i]);
2139        }
2140    }
2141
2142    #[test]
2143    fn test_hgbr_reproducibility() {
2144        let x =
2145            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
2146        let y = array![1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0];
2147
2148        let model = HistGradientBoostingRegressor::<f64>::new()
2149            .with_n_estimators(20)
2150            .with_min_samples_leaf(1)
2151            .with_max_leaf_nodes(None)
2152            .with_max_depth(Some(3))
2153            .with_random_state(123);
2154
2155        let fitted1 = model.fit(&x, &y).unwrap();
2156        let fitted2 = model.fit(&x, &y).unwrap();
2157
2158        let preds1 = fitted1.predict(&x).unwrap();
2159        let preds2 = fitted2.predict(&x).unwrap();
2160
2161        for (p1, p2) in preds1.iter().zip(preds2.iter()) {
2162            assert_relative_eq!(*p1, *p2, epsilon = 1e-10);
2163        }
2164    }
2165
2166    #[test]
2167    fn test_hgbr_feature_importances() {
2168        let x = Array2::from_shape_vec(
2169            (10, 3),
2170            vec![
2171                1.0, 0.0, 0.0, 2.0, 0.0, 0.0, 3.0, 0.0, 0.0, 4.0, 0.0, 0.0, 5.0, 0.0, 0.0, 6.0,
2172                0.0, 0.0, 7.0, 0.0, 0.0, 8.0, 0.0, 0.0, 9.0, 0.0, 0.0, 10.0, 0.0, 0.0,
2173            ],
2174        )
2175        .unwrap();
2176        let y = array![1.0, 1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0, 5.0];
2177
2178        let model = HistGradientBoostingRegressor::<f64>::new()
2179            .with_n_estimators(20)
2180            .with_min_samples_leaf(1)
2181            .with_max_leaf_nodes(None)
2182            .with_max_depth(Some(3))
2183            .with_random_state(42);
2184        let fitted = model.fit(&x, &y).unwrap();
2185        let importances = fitted.feature_importances();
2186
2187        assert_eq!(importances.len(), 3);
2188        // First feature should be most important since it's the only one with variance.
2189        assert!(
2190            importances[0] > importances[1],
2191            "Expected imp[0]={} > imp[1]={}",
2192            importances[0],
2193            importances[1]
2194        );
2195    }
2196
2197    #[test]
2198    fn test_hgbr_shape_mismatch_fit() {
2199        let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
2200        let y = array![1.0, 2.0];
2201
2202        let model = HistGradientBoostingRegressor::<f64>::new().with_n_estimators(5);
2203        assert!(model.fit(&x, &y).is_err());
2204    }
2205
2206    #[test]
2207    fn test_hgbr_shape_mismatch_predict() {
2208        let x =
2209            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
2210        let y = array![1.0, 2.0, 3.0, 4.0];
2211
2212        let model = HistGradientBoostingRegressor::<f64>::new()
2213            .with_n_estimators(5)
2214            .with_min_samples_leaf(1)
2215            .with_max_leaf_nodes(None)
2216            .with_max_depth(Some(3))
2217            .with_random_state(0);
2218        let fitted = model.fit(&x, &y).unwrap();
2219
2220        let x_bad = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
2221        assert!(fitted.predict(&x_bad).is_err());
2222    }
2223
2224    #[test]
2225    fn test_hgbr_empty_data() {
2226        let x = Array2::<f64>::zeros((0, 2));
2227        let y = Array1::<f64>::zeros(0);
2228
2229        let model = HistGradientBoostingRegressor::<f64>::new().with_n_estimators(5);
2230        assert!(model.fit(&x, &y).is_err());
2231    }
2232
2233    #[test]
2234    fn test_hgbr_zero_estimators() {
2235        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
2236        let y = array![1.0, 2.0, 3.0, 4.0];
2237
2238        let model = HistGradientBoostingRegressor::<f64>::new().with_n_estimators(0);
2239        assert!(model.fit(&x, &y).is_err());
2240    }
2241
2242    #[test]
2243    fn test_hgbr_invalid_learning_rate() {
2244        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
2245        let y = array![1.0, 2.0, 3.0, 4.0];
2246
2247        let model = HistGradientBoostingRegressor::<f64>::new()
2248            .with_n_estimators(5)
2249            .with_learning_rate(0.0);
2250        assert!(model.fit(&x, &y).is_err());
2251    }
2252
2253    #[test]
2254    fn test_hgbr_invalid_max_bins() {
2255        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
2256        let y = array![1.0, 2.0, 3.0, 4.0];
2257
2258        let model = HistGradientBoostingRegressor::<f64>::new()
2259            .with_n_estimators(5)
2260            .with_max_bins(1);
2261        assert!(model.fit(&x, &y).is_err());
2262    }
2263
2264    #[test]
2265    fn test_hgbr_default_trait() {
2266        let model = HistGradientBoostingRegressor::<f64>::default();
2267        assert_eq!(model.n_estimators, 100);
2268        assert!((model.learning_rate - 0.1).abs() < 1e-10);
2269        assert_eq!(model.max_bins, 255);
2270    }
2271
2272    #[test]
2273    fn test_hgbr_pipeline_integration() {
2274        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
2275        let y = array![1.0, 2.0, 3.0, 4.0];
2276
2277        let model = HistGradientBoostingRegressor::<f64>::new()
2278            .with_n_estimators(10)
2279            .with_min_samples_leaf(1)
2280            .with_max_leaf_nodes(None)
2281            .with_max_depth(Some(3))
2282            .with_random_state(42);
2283        let fitted = model.fit_pipeline(&x, &y).unwrap();
2284        let preds = fitted.predict_pipeline(&x).unwrap();
2285        assert_eq!(preds.len(), 4);
2286    }
2287
2288    #[test]
2289    fn test_hgbr_f32_support() {
2290        let x = Array2::from_shape_vec((4, 1), vec![1.0f32, 2.0, 3.0, 4.0]).unwrap();
2291        let y = Array1::from_vec(vec![1.0f32, 2.0, 3.0, 4.0]);
2292
2293        let model = HistGradientBoostingRegressor::<f32>::new()
2294            .with_n_estimators(10)
2295            .with_min_samples_leaf(1)
2296            .with_max_leaf_nodes(None)
2297            .with_max_depth(Some(3))
2298            .with_random_state(42);
2299        let fitted = model.fit(&x, &y).unwrap();
2300        let preds = fitted.predict(&x).unwrap();
2301        assert_eq!(preds.len(), 4);
2302    }
2303
2304    #[test]
2305    fn test_hgbr_nan_handling() {
        // Some feature values are NaN; the model should handle them natively.
2307        let x = Array2::from_shape_vec(
2308            (8, 1),
2309            vec![1.0, f64::NAN, 3.0, 4.0, 5.0, f64::NAN, 7.0, 8.0],
2310        )
2311        .unwrap();
2312        let y = array![1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0];
2313
2314        let model = HistGradientBoostingRegressor::<f64>::new()
2315            .with_n_estimators(20)
2316            .with_min_samples_leaf(1)
2317            .with_max_leaf_nodes(None)
2318            .with_max_depth(Some(3))
2319            .with_random_state(42);
2320        let fitted = model.fit(&x, &y).unwrap();
2321        let preds = fitted.predict(&x).unwrap();
2322        assert_eq!(preds.len(), 8);
2323        // Should still produce finite predictions for all samples (including NaN inputs).
2324        for p in &preds {
2325            assert!(p.is_finite(), "Expected finite prediction, got {p}");
2326        }
2327    }
2328
2329    #[test]
2330    fn test_hgbr_convergence() {
2331        // MSE should decrease as we add more estimators.
2332        let x =
2333            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
2334        let y = array![1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0];
2335
2336        let mse = |preds: &Array1<f64>, y: &Array1<f64>| -> f64 {
2337            preds
2338                .iter()
2339                .zip(y.iter())
2340                .map(|(p, t)| (p - t).powi(2))
2341                .sum::<f64>()
2342                / y.len() as f64
2343        };
2344
2345        let model_few = HistGradientBoostingRegressor::<f64>::new()
2346            .with_n_estimators(5)
2347            .with_min_samples_leaf(1)
2348            .with_max_leaf_nodes(None)
2349            .with_max_depth(Some(3))
2350            .with_random_state(42);
2351        let fitted_few = model_few.fit(&x, &y).unwrap();
2352        let preds_few = fitted_few.predict(&x).unwrap();
2353        let mse_few = mse(&preds_few, &y);
2354
2355        let model_many = HistGradientBoostingRegressor::<f64>::new()
2356            .with_n_estimators(50)
2357            .with_min_samples_leaf(1)
2358            .with_max_leaf_nodes(None)
2359            .with_max_depth(Some(3))
2360            .with_random_state(42);
2361        let fitted_many = model_many.fit(&x, &y).unwrap();
2362        let preds_many = fitted_many.predict(&x).unwrap();
2363        let mse_many = mse(&preds_many, &y);
2364
2365        assert!(
2366            mse_many < mse_few,
2367            "Expected MSE to decrease with more estimators: {mse_many} (50) vs {mse_few} (5)"
2368        );
2369    }
2370
2371    #[test]
2372    fn test_hgbr_max_leaf_nodes() {
2373        // Test that best-first growth with max_leaf_nodes works.
2374        let x =
2375            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
2376        let y = array![1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0];
2377
2378        let model = HistGradientBoostingRegressor::<f64>::new()
2379            .with_n_estimators(20)
2380            .with_min_samples_leaf(1)
2381            .with_max_leaf_nodes(Some(4))
2382            .with_random_state(42);
2383        let fitted = model.fit(&x, &y).unwrap();
2384        let preds = fitted.predict(&x).unwrap();
2385        assert_eq!(preds.len(), 8);
2386    }
2387
2388    #[test]
2389    fn test_hgbr_l2_regularization() {
2390        let x =
2391            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
2392        let y = array![1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0];
2393
2394        // With very high regularization, predictions should be closer to the mean.
2395        let model_noreg = HistGradientBoostingRegressor::<f64>::new()
2396            .with_n_estimators(20)
2397            .with_min_samples_leaf(1)
2398            .with_max_leaf_nodes(None)
2399            .with_max_depth(Some(3))
2400            .with_l2_regularization(0.0)
2401            .with_random_state(42);
2402        let fitted_noreg = model_noreg.fit(&x, &y).unwrap();
2403        let preds_noreg = fitted_noreg.predict(&x).unwrap();
2404
2405        let model_highreg = HistGradientBoostingRegressor::<f64>::new()
2406            .with_n_estimators(20)
2407            .with_min_samples_leaf(1)
2408            .with_max_leaf_nodes(None)
2409            .with_max_depth(Some(3))
2410            .with_l2_regularization(100.0)
2411            .with_random_state(42);
2412        let fitted_highreg = model_highreg.fit(&x, &y).unwrap();
2413        let preds_highreg = fitted_highreg.predict(&x).unwrap();
2414
2415        // With high reg, variance of predictions should be smaller.
2416        let var = |preds: &Array1<f64>| -> f64 {
2417            let mean = preds.mean().unwrap();
2418            preds.iter().map(|p| (p - mean).powi(2)).sum::<f64>() / preds.len() as f64
2419        };
2420
2421        assert!(
2422            var(&preds_highreg) < var(&preds_noreg),
2423            "High regularization should reduce prediction variance"
2424        );
2425    }
2426
2427    // -- Classifier tests --
2428
2429    #[test]
2430    fn test_hgbc_binary_simple() {
2431        let x = Array2::from_shape_vec(
2432            (8, 2),
2433            vec![
2434                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
2435            ],
2436        )
2437        .unwrap();
2438        let y = array![0, 0, 0, 0, 1, 1, 1, 1];
2439
2440        let model = HistGradientBoostingClassifier::<f64>::new()
2441            .with_n_estimators(50)
2442            .with_learning_rate(0.1)
2443            .with_min_samples_leaf(1)
2444            .with_max_leaf_nodes(None)
2445            .with_max_depth(Some(3))
2446            .with_random_state(42);
2447        let fitted = model.fit(&x, &y).unwrap();
2448        let preds = fitted.predict(&x).unwrap();
2449
2450        assert_eq!(preds.len(), 8);
2451        for i in 0..4 {
2452            assert_eq!(preds[i], 0, "Expected 0 at index {}, got {}", i, preds[i]);
2453        }
2454        for i in 4..8 {
2455            assert_eq!(preds[i], 1, "Expected 1 at index {}, got {}", i, preds[i]);
2456        }
2457    }
2458
2459    #[test]
2460    fn test_hgbc_multiclass() {
2461        let x = Array2::from_shape_vec((9, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0])
2462            .unwrap();
2463        let y = array![0, 0, 0, 1, 1, 1, 2, 2, 2];
2464
2465        let model = HistGradientBoostingClassifier::<f64>::new()
2466            .with_n_estimators(50)
2467            .with_learning_rate(0.1)
2468            .with_min_samples_leaf(1)
2469            .with_max_leaf_nodes(None)
2470            .with_max_depth(Some(3))
2471            .with_random_state(42);
2472        let fitted = model.fit(&x, &y).unwrap();
2473        let preds = fitted.predict(&x).unwrap();
2474
2475        assert_eq!(preds.len(), 9);
2476        let correct = preds.iter().zip(y.iter()).filter(|(p, t)| p == t).count();
2477        assert!(
2478            correct >= 6,
2479            "Expected at least 6/9 correct, got {correct}/9"
2480        );
2481    }
2482
2483    #[test]
2484    fn test_hgbc_has_classes() {
2485        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
2486        let y = array![0, 1, 2, 0, 1, 2];
2487
2488        let model = HistGradientBoostingClassifier::<f64>::new()
2489            .with_n_estimators(5)
2490            .with_min_samples_leaf(1)
2491            .with_max_leaf_nodes(None)
2492            .with_max_depth(Some(3))
2493            .with_random_state(0);
2494        let fitted = model.fit(&x, &y).unwrap();
2495
2496        assert_eq!(fitted.classes(), &[0, 1, 2]);
2497        assert_eq!(fitted.n_classes(), 3);
2498    }
2499
2500    #[test]
2501    fn test_hgbc_reproducibility() {
2502        let x = Array2::from_shape_vec(
2503            (8, 2),
2504            vec![
2505                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
2506            ],
2507        )
2508        .unwrap();
2509        let y = array![0, 0, 0, 0, 1, 1, 1, 1];
2510
2511        let model = HistGradientBoostingClassifier::<f64>::new()
2512            .with_n_estimators(10)
2513            .with_min_samples_leaf(1)
2514            .with_max_leaf_nodes(None)
2515            .with_max_depth(Some(3))
2516            .with_random_state(42);
2517
2518        let fitted1 = model.fit(&x, &y).unwrap();
2519        let fitted2 = model.fit(&x, &y).unwrap();
2520
2521        let preds1 = fitted1.predict(&x).unwrap();
2522        let preds2 = fitted2.predict(&x).unwrap();
2523        assert_eq!(preds1, preds2);
2524    }
2525
2526    #[test]
2527    fn test_hgbc_feature_importances() {
2528        let x = Array2::from_shape_vec(
2529            (10, 3),
2530            vec![
2531                1.0, 0.0, 0.0, 2.0, 0.0, 0.0, 3.0, 0.0, 0.0, 4.0, 0.0, 0.0, 5.0, 0.0, 0.0, 6.0,
2532                0.0, 0.0, 7.0, 0.0, 0.0, 8.0, 0.0, 0.0, 9.0, 0.0, 0.0, 10.0, 0.0, 0.0,
2533            ],
2534        )
2535        .unwrap();
2536        let y = array![0, 0, 0, 0, 0, 1, 1, 1, 1, 1];
2537
2538        let model = HistGradientBoostingClassifier::<f64>::new()
2539            .with_n_estimators(20)
2540            .with_min_samples_leaf(1)
2541            .with_max_leaf_nodes(None)
2542            .with_max_depth(Some(3))
2543            .with_random_state(42);
2544        let fitted = model.fit(&x, &y).unwrap();
2545        let importances = fitted.feature_importances();
2546
2547        assert_eq!(importances.len(), 3);
2548        assert!(importances[0] > importances[1]);
2549    }
2550
2551    #[test]
2552    fn test_hgbc_shape_mismatch_fit() {
2553        let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
2554        let y = array![0, 1];
2555
2556        let model = HistGradientBoostingClassifier::<f64>::new().with_n_estimators(5);
2557        assert!(model.fit(&x, &y).is_err());
2558    }
2559
2560    #[test]
2561    fn test_hgbc_shape_mismatch_predict() {
2562        let x =
2563            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
2564        let y = array![0, 0, 1, 1];
2565
2566        let model = HistGradientBoostingClassifier::<f64>::new()
2567            .with_n_estimators(5)
2568            .with_min_samples_leaf(1)
2569            .with_max_leaf_nodes(None)
2570            .with_max_depth(Some(3))
2571            .with_random_state(0);
2572        let fitted = model.fit(&x, &y).unwrap();
2573
2574        let x_bad = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
2575        assert!(fitted.predict(&x_bad).is_err());
2576    }
2577
2578    #[test]
2579    fn test_hgbc_empty_data() {
2580        let x = Array2::<f64>::zeros((0, 2));
2581        let y = Array1::<usize>::zeros(0);
2582
2583        let model = HistGradientBoostingClassifier::<f64>::new().with_n_estimators(5);
2584        assert!(model.fit(&x, &y).is_err());
2585    }
2586
2587    #[test]
2588    fn test_hgbc_single_class() {
2589        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
2590        let y = array![0, 0, 0];
2591
2592        let model = HistGradientBoostingClassifier::<f64>::new().with_n_estimators(5);
2593        assert!(model.fit(&x, &y).is_err());
2594    }
2595
2596    #[test]
2597    fn test_hgbc_zero_estimators() {
2598        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
2599        let y = array![0, 0, 1, 1];
2600
2601        let model = HistGradientBoostingClassifier::<f64>::new().with_n_estimators(0);
2602        assert!(model.fit(&x, &y).is_err());
2603    }
2604
2605    #[test]
2606    fn test_hgbc_pipeline_integration() {
2607        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
2608        let y = Array1::from_vec(vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0]);
2609
2610        let model = HistGradientBoostingClassifier::<f64>::new()
2611            .with_n_estimators(10)
2612            .with_min_samples_leaf(1)
2613            .with_max_leaf_nodes(None)
2614            .with_max_depth(Some(3))
2615            .with_random_state(42);
2616        let fitted = model.fit_pipeline(&x, &y).unwrap();
2617        let preds = fitted.predict_pipeline(&x).unwrap();
2618        assert_eq!(preds.len(), 6);
2619    }
2620
2621    #[test]
2622    fn test_hgbc_f32_support() {
2623        let x = Array2::from_shape_vec((6, 1), vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
2624        let y = array![0, 0, 0, 1, 1, 1];
2625
2626        let model = HistGradientBoostingClassifier::<f32>::new()
2627            .with_n_estimators(10)
2628            .with_min_samples_leaf(1)
2629            .with_max_leaf_nodes(None)
2630            .with_max_depth(Some(3))
2631            .with_random_state(42);
2632        let fitted = model.fit(&x, &y).unwrap();
2633        let preds = fitted.predict(&x).unwrap();
2634        assert_eq!(preds.len(), 6);
2635    }
2636
2637    #[test]
2638    fn test_hgbc_default_trait() {
2639        let model = HistGradientBoostingClassifier::<f64>::default();
2640        assert_eq!(model.n_estimators, 100);
2641        assert!((model.learning_rate - 0.1).abs() < 1e-10);
2642        assert_eq!(model.max_bins, 255);
2643    }
2644
2645    #[test]
2646    fn test_hgbc_non_contiguous_labels() {
2647        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
2648        let y = array![10, 10, 10, 20, 20, 20];
2649
2650        let model = HistGradientBoostingClassifier::<f64>::new()
2651            .with_n_estimators(20)
2652            .with_min_samples_leaf(1)
2653            .with_max_leaf_nodes(None)
2654            .with_max_depth(Some(3))
2655            .with_random_state(42);
2656        let fitted = model.fit(&x, &y).unwrap();
2657        let preds = fitted.predict(&x).unwrap();
2658
2659        assert_eq!(preds.len(), 6);
2660        for &p in &preds {
2661            assert!(p == 10 || p == 20);
2662        }
2663    }
2664
2665    #[test]
2666    fn test_hgbc_nan_handling() {
2667        // Classifier should handle NaN features.
2668        let x = Array2::from_shape_vec(
2669            (8, 2),
2670            vec![
2671                1.0,
2672                f64::NAN,
2673                2.0,
2674                3.0,
2675                f64::NAN,
2676                3.0,
2677                4.0,
2678                4.0,
2679                5.0,
2680                6.0,
2681                6.0,
2682                f64::NAN,
2683                7.0,
2684                8.0,
2685                f64::NAN,
2686                9.0,
2687            ],
2688        )
2689        .unwrap();
2690        let y = array![0, 0, 0, 0, 1, 1, 1, 1];
2691
2692        let model = HistGradientBoostingClassifier::<f64>::new()
2693            .with_n_estimators(20)
2694            .with_min_samples_leaf(1)
2695            .with_max_leaf_nodes(None)
2696            .with_max_depth(Some(3))
2697            .with_random_state(42);
2698        let fitted = model.fit(&x, &y).unwrap();
2699        let preds = fitted.predict(&x).unwrap();
2700        assert_eq!(preds.len(), 8);
2701        // All predictions should be valid class labels.
2702        for &p in &preds {
2703            assert!(p == 0 || p == 1);
2704        }
2705    }
2706
2707    // -- Comparison with standard GBM --
2708
2709    #[test]
2710    fn test_hist_vs_standard_gbm_similar_accuracy() {
2711        // Both models should achieve comparable accuracy on a simple task.
2712        use crate::GradientBoostingRegressor;
2713
2714        let x = Array2::from_shape_vec(
2715            (10, 1),
2716            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0],
2717        )
2718        .unwrap();
2719        let y = array![1.0, 1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0, 5.0];
2720
2721        let mse = |preds: &Array1<f64>, y: &Array1<f64>| -> f64 {
2722            preds
2723                .iter()
2724                .zip(y.iter())
2725                .map(|(p, t)| (p - t).powi(2))
2726                .sum::<f64>()
2727                / y.len() as f64
2728        };
2729
2730        // Standard GBM.
2731        let std_model = GradientBoostingRegressor::<f64>::new()
2732            .with_n_estimators(50)
2733            .with_learning_rate(0.1)
2734            .with_max_depth(Some(3))
2735            .with_random_state(42);
2736        let std_fitted = std_model.fit(&x, &y).unwrap();
2737        let std_preds = std_fitted.predict(&x).unwrap();
2738        let std_mse = mse(&std_preds, &y);
2739
2740        // Histogram GBM.
2741        let hist_model = HistGradientBoostingRegressor::<f64>::new()
2742            .with_n_estimators(50)
2743            .with_learning_rate(0.1)
2744            .with_min_samples_leaf(1)
2745            .with_max_leaf_nodes(None)
2746            .with_max_depth(Some(3))
2747            .with_random_state(42);
2748        let hist_fitted = hist_model.fit(&x, &y).unwrap();
2749        let hist_preds = hist_fitted.predict(&x).unwrap();
2750        let hist_mse = mse(&hist_preds, &y);
2751
2752        // Both should have low MSE on this simple task.
2753        assert!(std_mse < 1.0, "Standard GBM MSE too high: {std_mse}");
2754        assert!(hist_mse < 1.0, "Hist GBM MSE too high: {hist_mse}");
2755    }
2756}