Skip to main content

ferrolearn_tree/
hist_gradient_boosting.rs

//! Histogram-based gradient boosting classifiers and regressors.
//!
//! This module provides [`HistGradientBoostingClassifier`] and
//! [`HistGradientBoostingRegressor`], which implement histogram-based gradient
//! boosting trees inspired by LightGBM / scikit-learn's HistGradientBoosting*.
//!
//! # Key design choices
//!
//! - **Feature binning**: continuous features are discretised into up to
//!   `max_bins` (default 256) bins using quantile-based bin edges. NaN values
//!   are assigned a dedicated bin.
//! - **Histogram-based split finding**: at each node, gradient/hessian sums are
//!   accumulated into per-bin histograms, making split finding O(n_bins) instead
//!   of O(n log n).
//! - **Subtraction trick**: the child histogram with more samples is computed by
//!   subtracting the smaller child's histogram from the parent, halving the work.
//! - **Missing value support**: NaN values are routed to the child that yields
//!   the best split, enabling native handling of missing data.
//!
//! # Regression Losses
//!
//! - **`LeastSquares`**: mean squared error; gradient = `y - F(x)`, hessian = 1.
//! - **`LeastAbsoluteDeviation`**: gradient = `sign(y - F(x))`, hessian = 1.
//!
//! # Classification Loss
//!
//! - **`LogLoss`**: binary and multiclass logistic loss (one-vs-rest via softmax).
//!
//! # Examples
//!
//! ```
//! use ferrolearn_tree::HistGradientBoostingRegressor;
//! use ferrolearn_core::{Fit, Predict};
//! use ndarray::{array, Array1, Array2};
//!
//! let x = Array2::from_shape_vec((8, 1), vec![
//!     1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
//! ]).unwrap();
//! let y = array![1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0];
//!
//! let model = HistGradientBoostingRegressor::<f64>::new()
//!     .with_n_estimators(50)
//!     .with_learning_rate(0.1)
//!     .with_random_state(42);
//! let fitted = model.fit(&x, &y).unwrap();
//! let preds = fitted.predict(&x).unwrap();
//! assert_eq!(preds.len(), 8);
//! ```

50use ferrolearn_core::error::FerroError;
51use ferrolearn_core::introspection::{HasClasses, HasFeatureImportances};
52use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
53use ferrolearn_core::traits::{Fit, Predict};
54use ndarray::{Array1, Array2};
55use num_traits::{Float, FromPrimitive, ToPrimitive};
56use serde::{Deserialize, Serialize};
57
58// ---------------------------------------------------------------------------
59// Loss enums
60// ---------------------------------------------------------------------------
61
/// Loss function for histogram-based gradient boosting regression.
///
/// See the module-level docs for the gradient/hessian associated with each
/// variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum HistRegressionLoss {
    /// Least squares (L2) loss.
    LeastSquares,
    /// Least absolute deviation (L1) loss.
    LeastAbsoluteDeviation,
}
70
/// Loss function for histogram-based gradient boosting classification.
///
/// Currently only log-loss is supported; the enum exists so additional losses
/// can be added without breaking the public API.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum HistClassificationLoss {
    /// Log-loss (logistic / cross-entropy) for binary and multiclass.
    LogLoss,
}
77
78// ---------------------------------------------------------------------------
79// Internal: histogram tree node
80// ---------------------------------------------------------------------------
81
/// A node in a histogram-based gradient boosting tree.
///
/// Nodes are stored in a flat `Vec` with the root at index 0; `Split` nodes
/// reference their children by index into that same vec.
///
/// Unlike [`crate::decision_tree::Node`], thresholds are stored as bin indices
/// rather than raw feature values. Prediction at a split checks whether the
/// sample's bin index for the feature is `<= threshold_bin`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum HistNode<F> {
    /// An internal split node.
    Split {
        /// Feature index used for the split.
        feature: usize,
        /// Bin index threshold; samples with `bin[feature] <= threshold_bin`
        /// go left.
        threshold_bin: u16,
        /// Whether NaN values should go left (`true`) or right (`false`).
        nan_goes_left: bool,
        /// Index of the left child node in the flat vec.
        left: usize,
        /// Index of the right child node in the flat vec.
        right: usize,
        /// Weighted gain from this split (for feature importance).
        gain: F,
        /// Number of samples that reached this node during training.
        n_samples: usize,
    },
    /// A leaf node that stores a prediction value.
    Leaf {
        /// Predicted value (raw leaf output, e.g. gradient step).
        value: F,
        /// Number of samples that reached this node during training.
        n_samples: usize,
    },
}
115
116// ---------------------------------------------------------------------------
117// Internal: binning
118// ---------------------------------------------------------------------------
119
/// The special bin index reserved for NaN / missing values.
///
/// Because `u16::MAX` is reserved, regular bin indices must stay below it.
const NAN_BIN: u16 = u16::MAX;

/// Bin edges for a single feature, plus the number of non-NaN bins.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct FeatureBinInfo<F> {
    /// Sorted bin edges (upper thresholds). `edges[b]` is the upper bound
    /// for bin `b`. The last bin captures everything above `edges[len-2]`.
    edges: Vec<F>,
    /// Number of non-NaN bins for this feature (at most `max_bins`).
    /// Zero means the feature was entirely NaN during binning.
    n_bins: u16,
    /// Whether any NaN was observed for this feature during binning.
    has_nan: bool,
}
134
135/// Compute quantile-based bin edges for every feature.
136fn compute_bin_edges<F: Float>(x: &Array2<F>, max_bins: u16) -> Vec<FeatureBinInfo<F>> {
137    let n_features = x.ncols();
138    let n_samples = x.nrows();
139    let mut infos = Vec::with_capacity(n_features);
140
141    for j in 0..n_features {
142        let col = x.column(j);
143        // Separate non-NaN values and sort them.
144        let mut vals: Vec<F> = col.iter().copied().filter(|v| !v.is_nan()).collect();
145        vals.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
146
147        let has_nan = vals.len() < n_samples;
148
149        if vals.is_empty() {
150            // All NaN.
151            infos.push(FeatureBinInfo {
152                edges: Vec::new(),
153                n_bins: 0,
154                has_nan: true,
155            });
156            continue;
157        }
158
159        // Deduplicate to find unique values.
160        let mut unique: Vec<F> = Vec::new();
161        for &v in &vals {
162            if unique.is_empty() || (v - *unique.last().unwrap()).abs() > F::epsilon() {
163                unique.push(v);
164            }
165        }
166
167        let n_unique = unique.len();
168        let actual_bins = (max_bins as usize).min(n_unique);
169
170        if actual_bins <= 1 {
171            // Only one unique value — one bin, edge is that value.
172            infos.push(FeatureBinInfo {
173                edges: vec![unique[0]],
174                n_bins: 1,
175                has_nan,
176            });
177            continue;
178        }
179
180        // Compute quantile-based bin edges.
181        let mut edges = Vec::with_capacity(actual_bins);
182        for b in 1..actual_bins {
183            let frac = b as f64 / actual_bins as f64;
184            let idx_f = frac * (n_unique as f64 - 1.0);
185            let lo = idx_f.floor() as usize;
186            let hi = (lo + 1).min(n_unique - 1);
187            let t = F::from(idx_f - lo as f64).unwrap();
188            let edge = unique[lo] * (F::one() - t) + unique[hi] * t;
189            // Avoid duplicate edges.
190            if edges.is_empty() || (edge - *edges.last().unwrap()).abs() > F::epsilon() {
191                edges.push(edge);
192            }
193        }
194        // Add a final edge for the upper bound (max value).
195        let last = *unique.last().unwrap();
196        if edges.is_empty() || (last - *edges.last().unwrap()).abs() > F::epsilon() {
197            edges.push(last);
198        }
199
200        let n_bins = edges.len() as u16;
201        infos.push(FeatureBinInfo {
202            edges,
203            n_bins,
204            has_nan,
205        });
206    }
207
208    infos
209}
210
211/// Map a single feature value to its bin index given bin edges.
212#[inline]
213fn map_to_bin<F: Float>(value: F, info: &FeatureBinInfo<F>) -> u16 {
214    if value.is_nan() {
215        return NAN_BIN;
216    }
217    if info.n_bins == 0 {
218        return NAN_BIN;
219    }
220    // Binary search: find the first edge >= value.
221    let edges = &info.edges;
222    let mut lo: usize = 0;
223    let mut hi: usize = edges.len();
224    while lo < hi {
225        let mid = lo + (hi - lo) / 2;
226        if edges[mid] < value {
227            lo = mid + 1;
228        } else {
229            hi = mid;
230        }
231    }
232    if lo >= edges.len() {
233        (info.n_bins - 1).min(edges.len() as u16 - 1)
234    } else {
235        lo as u16
236    }
237}
238
239/// Bin all samples for all features.
240fn bin_data<F: Float>(x: &Array2<F>, bin_infos: &[FeatureBinInfo<F>]) -> Vec<Vec<u16>> {
241    let (n_samples, n_features) = x.dim();
242    // bins[i][j] = bin index for sample i, feature j.
243    let mut bins = vec![vec![0u16; n_features]; n_samples];
244    for j in 0..n_features {
245        for i in 0..n_samples {
246            bins[i][j] = map_to_bin(x[[i, j]], &bin_infos[j]);
247        }
248    }
249    bins
250}
251
252// ---------------------------------------------------------------------------
253// Internal: histogram-based tree building
254// ---------------------------------------------------------------------------
255
/// Gradient/hessian accumulator for a single bin.
#[derive(Debug, Clone, Copy, Default)]
struct BinEntry<F> {
    /// Sum of gradients of the samples that fell in this bin.
    grad_sum: F,
    /// Sum of hessians of the samples that fell in this bin.
    hess_sum: F,
    /// Number of samples that fell in this bin.
    count: usize,
}

/// A histogram for one feature: one entry per bin, plus an optional NaN entry.
struct FeatureHistogram<F> {
    /// Per-bin accumulators, indexed by bin index (`0..n_bins`).
    bins: Vec<BinEntry<F>>,
    /// Accumulator for samples whose value for this feature is NaN.
    nan_entry: BinEntry<F>,
}
269
270impl<F: Float> FeatureHistogram<F> {
271    fn new(n_bins: u16) -> Self {
272        Self {
273            bins: vec![
274                BinEntry {
275                    grad_sum: F::zero(),
276                    hess_sum: F::zero(),
277                    count: 0,
278                };
279                n_bins as usize
280            ],
281            nan_entry: BinEntry {
282                grad_sum: F::zero(),
283                hess_sum: F::zero(),
284                count: 0,
285            },
286        }
287    }
288}
289
290/// Build histograms for all features from scratch over the given sample indices.
291fn build_histograms<F: Float>(
292    bin_data: &[Vec<u16>],
293    gradients: &[F],
294    hessians: &[F],
295    sample_indices: &[usize],
296    bin_infos: &[FeatureBinInfo<F>],
297) -> Vec<FeatureHistogram<F>> {
298    let n_features = bin_infos.len();
299    let mut histograms: Vec<FeatureHistogram<F>> = bin_infos
300        .iter()
301        .map(|info| FeatureHistogram::new(info.n_bins))
302        .collect();
303
304    for &i in sample_indices {
305        let g = gradients[i];
306        let h = hessians[i];
307        for j in 0..n_features {
308            let b = bin_data[i][j];
309            if b == NAN_BIN {
310                histograms[j].nan_entry.grad_sum = histograms[j].nan_entry.grad_sum + g;
311                histograms[j].nan_entry.hess_sum = histograms[j].nan_entry.hess_sum + h;
312                histograms[j].nan_entry.count += 1;
313            } else {
314                let entry = &mut histograms[j].bins[b as usize];
315                entry.grad_sum = entry.grad_sum + g;
316                entry.hess_sum = entry.hess_sum + h;
317                entry.count += 1;
318            }
319        }
320    }
321
322    histograms
323}
324
325/// Subtraction trick: compute child histogram from parent and sibling.
326fn subtract_histograms<F: Float>(
327    parent: &[FeatureHistogram<F>],
328    sibling: &[FeatureHistogram<F>],
329    bin_infos: &[FeatureBinInfo<F>],
330) -> Vec<FeatureHistogram<F>> {
331    let n_features = bin_infos.len();
332    let mut result: Vec<FeatureHistogram<F>> = bin_infos
333        .iter()
334        .map(|info| FeatureHistogram::new(info.n_bins))
335        .collect();
336
337    for j in 0..n_features {
338        let n_bins = bin_infos[j].n_bins as usize;
339        for b in 0..n_bins {
340            result[j].bins[b].grad_sum = parent[j].bins[b].grad_sum - sibling[j].bins[b].grad_sum;
341            result[j].bins[b].hess_sum = parent[j].bins[b].hess_sum - sibling[j].bins[b].hess_sum;
342            // count is usize so we need saturating_sub for safety
343            result[j].bins[b].count = parent[j].bins[b]
344                .count
345                .saturating_sub(sibling[j].bins[b].count);
346        }
347        // NaN bin
348        result[j].nan_entry.grad_sum = parent[j].nan_entry.grad_sum - sibling[j].nan_entry.grad_sum;
349        result[j].nan_entry.hess_sum = parent[j].nan_entry.hess_sum - sibling[j].nan_entry.hess_sum;
350        result[j].nan_entry.count = parent[j]
351            .nan_entry
352            .count
353            .saturating_sub(sibling[j].nan_entry.count);
354    }
355
356    result
357}
358
/// Result of finding the best split for a node.
struct SplitCandidate<F> {
    /// Feature index the split tests.
    feature: usize,
    /// Bin threshold; samples with `bin <= threshold_bin` go left.
    threshold_bin: u16,
    /// Gain of the split (strictly positive for returned candidates).
    gain: F,
    /// Whether NaN samples are routed to the left child.
    nan_goes_left: bool,
}
366
367/// Find the best split across all features from histograms.
368///
369/// Uses the standard gain formula:
370///   gain = (G_L^2 / (H_L + lambda)) + (G_R^2 / (H_R + lambda)) - (G_parent^2 / (H_parent + lambda))
371fn find_best_split_from_histograms<F: Float>(
372    histograms: &[FeatureHistogram<F>],
373    bin_infos: &[FeatureBinInfo<F>],
374    total_grad: F,
375    total_hess: F,
376    total_count: usize,
377    l2_regularization: F,
378    min_samples_leaf: usize,
379) -> Option<SplitCandidate<F>> {
380    let n_features = bin_infos.len();
381    let parent_gain = total_grad * total_grad / (total_hess + l2_regularization);
382
383    let mut best: Option<SplitCandidate<F>> = None;
384
385    for j in 0..n_features {
386        let n_bins = bin_infos[j].n_bins as usize;
387        if n_bins <= 1 {
388            continue;
389        }
390        let nan = &histograms[j].nan_entry;
391
392        // Try scanning left-to-right through bins.
393        // For each split position b, left contains bins 0..=b, right contains b+1..n_bins-1.
394        // NaN samples can go either left or right; we try both.
395        let mut left_grad = F::zero();
396        let mut left_hess = F::zero();
397        let mut left_count: usize = 0;
398
399        for b in 0..(n_bins - 1) {
400            let entry = &histograms[j].bins[b];
401            left_grad = left_grad + entry.grad_sum;
402            left_hess = left_hess + entry.hess_sum;
403            left_count += entry.count;
404
405            let right_grad_no_nan = total_grad - left_grad - nan.grad_sum;
406            let right_hess_no_nan = total_hess - left_hess - nan.hess_sum;
407            let right_count_no_nan = total_count
408                .saturating_sub(left_count)
409                .saturating_sub(nan.count);
410
411            // Try NaN goes left.
412            {
413                let lg = left_grad + nan.grad_sum;
414                let lh = left_hess + nan.hess_sum;
415                let lc = left_count + nan.count;
416                let rg = right_grad_no_nan;
417                let rh = right_hess_no_nan;
418                let rc = right_count_no_nan;
419
420                if lc >= min_samples_leaf && rc >= min_samples_leaf {
421                    let gain = lg * lg / (lh + l2_regularization)
422                        + rg * rg / (rh + l2_regularization)
423                        - parent_gain;
424                    if gain > F::zero() {
425                        let better = match &best {
426                            None => true,
427                            Some(curr) => gain > curr.gain,
428                        };
429                        if better {
430                            best = Some(SplitCandidate {
431                                feature: j,
432                                threshold_bin: b as u16,
433                                gain,
434                                nan_goes_left: true,
435                            });
436                        }
437                    }
438                }
439            }
440
441            // Try NaN goes right.
442            {
443                let lg = left_grad;
444                let lh = left_hess;
445                let lc = left_count;
446                let rg = right_grad_no_nan + nan.grad_sum;
447                let rh = right_hess_no_nan + nan.hess_sum;
448                let rc = right_count_no_nan + nan.count;
449
450                if lc >= min_samples_leaf && rc >= min_samples_leaf {
451                    let gain = lg * lg / (lh + l2_regularization)
452                        + rg * rg / (rh + l2_regularization)
453                        - parent_gain;
454                    if gain > F::zero() {
455                        let better = match &best {
456                            None => true,
457                            Some(curr) => gain > curr.gain,
458                        };
459                        if better {
460                            best = Some(SplitCandidate {
461                                feature: j,
462                                threshold_bin: b as u16,
463                                gain,
464                                nan_goes_left: false,
465                            });
466                        }
467                    }
468                }
469            }
470        }
471    }
472
473    best
474}
475
476/// Compute leaf value: -G / (H + lambda).
477#[inline]
478fn compute_leaf_value<F: Float>(grad_sum: F, hess_sum: F, l2_reg: F) -> F {
479    if hess_sum.abs() < F::epsilon() {
480        F::zero()
481    } else {
482        -grad_sum / (hess_sum + l2_reg)
483    }
484}
485
/// Parameters for histogram tree building.
struct HistTreeParams<F> {
    /// Maximum tree depth; `None` means unbounded.
    max_depth: Option<usize>,
    /// Minimum number of samples required in each child of a split.
    min_samples_leaf: usize,
    /// Maximum number of leaves; `Some(_)` selects best-first growth.
    max_leaf_nodes: Option<usize>,
    /// L2 regularization strength (the lambda in gain/leaf-value formulas).
    l2_regularization: F,
}
493
494/// Build a single histogram-based regression tree.
495///
496/// Returns a vector of `HistNode`s (flat representation).
497fn build_hist_tree<F: Float>(
498    binned: &[Vec<u16>],
499    gradients: &[F],
500    hessians: &[F],
501    sample_indices: &[usize],
502    bin_infos: &[FeatureBinInfo<F>],
503    params: &HistTreeParams<F>,
504) -> Vec<HistNode<F>> {
505    let n_features = bin_infos.len();
506    let _ = n_features; // used implicitly through bin_infos
507
508    if sample_indices.is_empty() {
509        return vec![HistNode::Leaf {
510            value: F::zero(),
511            n_samples: 0,
512        }];
513    }
514
515    // Use a work queue approach for best-first (leaf-wise) or depth-first growth.
516    // If max_leaf_nodes is set, use best-first; otherwise depth-first.
517    if params.max_leaf_nodes.is_some() {
518        build_hist_tree_best_first(
519            binned,
520            gradients,
521            hessians,
522            sample_indices,
523            bin_infos,
524            params,
525        )
526    } else {
527        let mut nodes = Vec::new();
528        let hist = build_histograms(binned, gradients, hessians, sample_indices, bin_infos);
529        build_hist_tree_recursive(
530            binned,
531            gradients,
532            hessians,
533            sample_indices,
534            bin_infos,
535            params,
536            &hist,
537            0,
538            &mut nodes,
539        );
540        nodes
541    }
542}
543
/// Recursive depth-first histogram tree building.
///
/// Appends the subtree covering `sample_indices` to `nodes` and returns the
/// index of the subtree root within the flat vec. `histograms` must have been
/// accumulated over exactly `sample_indices`; children reuse them through the
/// subtraction trick.
#[allow(clippy::too_many_arguments)]
fn build_hist_tree_recursive<F: Float>(
    binned: &[Vec<u16>],
    gradients: &[F],
    hessians: &[F],
    sample_indices: &[usize],
    bin_infos: &[FeatureBinInfo<F>],
    params: &HistTreeParams<F>,
    histograms: &[FeatureHistogram<F>],
    depth: usize,
    nodes: &mut Vec<HistNode<F>>,
) -> usize {
    let n = sample_indices.len();
    // Gradient/hessian totals for this node — needed both for the leaf value
    // and for the split-gain computation.
    let grad_sum: F = sample_indices
        .iter()
        .map(|&i| gradients[i])
        .fold(F::zero(), |a, b| a + b);
    let hess_sum: F = sample_indices
        .iter()
        .map(|&i| hessians[i])
        .fold(F::zero(), |a, b| a + b);

    // Check stopping conditions.
    let at_max_depth = params.max_depth.is_some_and(|d| depth >= d);
    // A split needs at least `min_samples_leaf` samples on each side.
    let too_few = n < 2 * params.min_samples_leaf;

    if at_max_depth || too_few {
        let idx = nodes.len();
        nodes.push(HistNode::Leaf {
            value: compute_leaf_value(grad_sum, hess_sum, params.l2_regularization),
            n_samples: n,
        });
        return idx;
    }

    // Find best split.
    let split = find_best_split_from_histograms(
        histograms,
        bin_infos,
        grad_sum,
        hess_sum,
        n,
        params.l2_regularization,
        params.min_samples_leaf,
    );

    // No positive-gain split available: terminate with a leaf.
    let split = match split {
        Some(s) => s,
        None => {
            let idx = nodes.len();
            nodes.push(HistNode::Leaf {
                value: compute_leaf_value(grad_sum, hess_sum, params.l2_regularization),
                n_samples: n,
            });
            return idx;
        }
    };

    // Partition samples into left and right.
    let (left_indices, right_indices): (Vec<usize>, Vec<usize>) =
        sample_indices.iter().partition(|&&i| {
            let b = binned[i][split.feature];
            if b == NAN_BIN {
                split.nan_goes_left
            } else {
                b <= split.threshold_bin
            }
        });

    // Defensive: a degenerate partition falls back to a leaf.
    if left_indices.is_empty() || right_indices.is_empty() {
        let idx = nodes.len();
        nodes.push(HistNode::Leaf {
            value: compute_leaf_value(grad_sum, hess_sum, params.l2_regularization),
            n_samples: n,
        });
        return idx;
    }

    // Build histograms for the smaller child, then use subtraction trick.
    let (small_indices, _large_indices, small_is_left) =
        if left_indices.len() <= right_indices.len() {
            (&left_indices, &right_indices, true)
        } else {
            (&right_indices, &left_indices, false)
        };

    let small_hist = build_histograms(binned, gradients, hessians, small_indices, bin_infos);
    let large_hist = subtract_histograms(histograms, &small_hist, bin_infos);

    let (left_hist, right_hist) = if small_is_left {
        (small_hist, large_hist)
    } else {
        (large_hist, small_hist)
    };

    // Reserve a placeholder for this split node. Children are appended first;
    // the placeholder is overwritten below once their indices are known.
    let node_idx = nodes.len();
    nodes.push(HistNode::Leaf {
        value: F::zero(),
        n_samples: 0,
    }); // placeholder

    // Recurse.
    let left_idx = build_hist_tree_recursive(
        binned,
        gradients,
        hessians,
        &left_indices,
        bin_infos,
        params,
        &left_hist,
        depth + 1,
        nodes,
    );
    let right_idx = build_hist_tree_recursive(
        binned,
        gradients,
        hessians,
        &right_indices,
        bin_infos,
        params,
        &right_hist,
        depth + 1,
        nodes,
    );

    // Replace the placeholder with the real split node.
    nodes[node_idx] = HistNode::Split {
        feature: split.feature,
        threshold_bin: split.threshold_bin,
        nan_goes_left: split.nan_goes_left,
        left: left_idx,
        right: right_idx,
        gain: split.gain,
        n_samples: n,
    };

    node_idx
}
683
/// Entry in the best-first priority queue.
///
/// A task records an already-evaluated candidate split for a leaf node; the
/// gain is stored as `f64` so tasks can be compared without the generic `F`.
struct SplitTask {
    /// Indices of samples at this node.
    sample_indices: Vec<usize>,
    /// The node index in the flat node vec (a leaf placeholder).
    node_idx: usize,
    /// Depth of this node.
    depth: usize,
    /// The gain of the best split at this node.
    gain: f64,
    /// Feature of the best split.
    feature: usize,
    /// Bin threshold of the best split.
    threshold_bin: u16,
    /// Whether NaN goes left.
    nan_goes_left: bool,
}
701
/// Build a histogram tree using best-first (leaf-wise) growth with max_leaf_nodes.
///
/// Every node is created as a leaf carrying its final value; when a leaf's
/// best candidate split is applied, the leaf is overwritten in place by a
/// `Split` node pointing at two freshly appended child leaves. Growth stops
/// when the leaf budget is reached or no splittable leaf remains.
fn build_hist_tree_best_first<F: Float>(
    binned: &[Vec<u16>],
    gradients: &[F],
    hessians: &[F],
    sample_indices: &[usize],
    bin_infos: &[FeatureBinInfo<F>],
    params: &HistTreeParams<F>,
) -> Vec<HistNode<F>> {
    let max_leaves = params.max_leaf_nodes.unwrap_or(usize::MAX);
    let mut nodes: Vec<HistNode<F>> = Vec::new();

    // Root node.
    let n = sample_indices.len();
    let grad_sum: F = sample_indices
        .iter()
        .map(|&i| gradients[i])
        .fold(F::zero(), |a, b| a + b);
    let hess_sum: F = sample_indices
        .iter()
        .map(|&i| hessians[i])
        .fold(F::zero(), |a, b| a + b);

    let root_idx = nodes.len();
    nodes.push(HistNode::Leaf {
        value: compute_leaf_value(grad_sum, hess_sum, params.l2_regularization),
        n_samples: n,
    });

    let root_hist = build_histograms(binned, gradients, hessians, sample_indices, bin_infos);
    let root_split = find_best_split_from_histograms(
        &root_hist,
        bin_infos,
        grad_sum,
        hess_sum,
        n,
        params.l2_regularization,
        params.min_samples_leaf,
    );

    // Queue of splittable leaves; each task carries the histograms of its
    // samples so children can reuse them via the subtraction trick.
    let mut pending: Vec<(SplitTask, Vec<FeatureHistogram<F>>)> = Vec::new();
    let mut n_leaves: usize = 1;

    if let Some(split) = root_split {
        let at_max_depth = params.max_depth.is_some_and(|d| d == 0);
        if !at_max_depth {
            pending.push((
                SplitTask {
                    sample_indices: sample_indices.to_vec(),
                    node_idx: root_idx,
                    depth: 0,
                    gain: split.gain.to_f64().unwrap_or(0.0),
                    feature: split.feature,
                    threshold_bin: split.threshold_bin,
                    nan_goes_left: split.nan_goes_left,
                },
                root_hist,
            ));
        }
    }

    while !pending.is_empty() && n_leaves < max_leaves {
        // Pick the task with the highest gain.
        // (Linear scan; each iteration removes one task and adds at most two.)
        let best_idx = pending
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| {
                a.0.gain
                    .partial_cmp(&b.0.gain)
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
            .map(|(i, _)| i)
            .unwrap();

        let (task, parent_hist) = pending.swap_remove(best_idx);

        // Partition.
        let (left_indices, right_indices): (Vec<usize>, Vec<usize>) =
            task.sample_indices.iter().partition(|&&i| {
                let b = binned[i][task.feature];
                if b == NAN_BIN {
                    task.nan_goes_left
                } else {
                    b <= task.threshold_bin
                }
            });

        // Degenerate partition: abandon the split and leave the node a leaf.
        if left_indices.is_empty() || right_indices.is_empty() {
            continue;
        }

        // Build histograms using subtraction trick.
        let (small_indices, _large_indices, small_is_left) =
            if left_indices.len() <= right_indices.len() {
                (&left_indices, &right_indices, true)
            } else {
                (&right_indices, &left_indices, false)
            };

        let small_hist = build_histograms(binned, gradients, hessians, small_indices, bin_infos);
        let large_hist = subtract_histograms(&parent_hist, &small_hist, bin_infos);

        let (left_hist, right_hist) = if small_is_left {
            (small_hist, large_hist)
        } else {
            (large_hist, small_hist)
        };

        // Create left and right leaf nodes.
        let left_grad: F = left_indices
            .iter()
            .map(|&i| gradients[i])
            .fold(F::zero(), |a, b| a + b);
        let left_hess: F = left_indices
            .iter()
            .map(|&i| hessians[i])
            .fold(F::zero(), |a, b| a + b);
        let right_grad: F = right_indices
            .iter()
            .map(|&i| gradients[i])
            .fold(F::zero(), |a, b| a + b);
        let right_hess: F = right_indices
            .iter()
            .map(|&i| hessians[i])
            .fold(F::zero(), |a, b| a + b);

        let left_idx = nodes.len();
        nodes.push(HistNode::Leaf {
            value: compute_leaf_value(left_grad, left_hess, params.l2_regularization),
            n_samples: left_indices.len(),
        });
        let right_idx = nodes.len();
        nodes.push(HistNode::Leaf {
            value: compute_leaf_value(right_grad, right_hess, params.l2_regularization),
            n_samples: right_indices.len(),
        });

        // Convert the parent leaf placeholder into a split node.
        nodes[task.node_idx] = HistNode::Split {
            feature: task.feature,
            threshold_bin: task.threshold_bin,
            nan_goes_left: task.nan_goes_left,
            left: left_idx,
            right: right_idx,
            gain: F::from(task.gain).unwrap(),
            n_samples: task.sample_indices.len(),
        };

        // One leaf became two, so net +1 leaf.
        n_leaves += 1;

        let child_depth = task.depth + 1;
        let at_max_depth = params.max_depth.is_some_and(|d| child_depth >= d);

        // Enqueue the children as future split candidates, but only if they
        // may legally be split and the leaf budget is not already exhausted.
        if !at_max_depth && n_leaves < max_leaves {
            // Try to split left child.
            if left_indices.len() >= 2 * params.min_samples_leaf {
                let left_split = find_best_split_from_histograms(
                    &left_hist,
                    bin_infos,
                    left_grad,
                    left_hess,
                    left_indices.len(),
                    params.l2_regularization,
                    params.min_samples_leaf,
                );
                if let Some(s) = left_split {
                    pending.push((
                        SplitTask {
                            sample_indices: left_indices,
                            node_idx: left_idx,
                            depth: child_depth,
                            gain: s.gain.to_f64().unwrap_or(0.0),
                            feature: s.feature,
                            threshold_bin: s.threshold_bin,
                            nan_goes_left: s.nan_goes_left,
                        },
                        left_hist,
                    ));
                }
            }

            // Try to split right child.
            if right_indices.len() >= 2 * params.min_samples_leaf {
                let right_split = find_best_split_from_histograms(
                    &right_hist,
                    bin_infos,
                    right_grad,
                    right_hess,
                    right_indices.len(),
                    params.l2_regularization,
                    params.min_samples_leaf,
                );
                if let Some(s) = right_split {
                    pending.push((
                        SplitTask {
                            sample_indices: right_indices,
                            node_idx: right_idx,
                            depth: child_depth,
                            gain: s.gain.to_f64().unwrap_or(0.0),
                            feature: s.feature,
                            threshold_bin: s.threshold_bin,
                            nan_goes_left: s.nan_goes_left,
                        },
                        right_hist,
                    ));
                }
            }
        }
    }

    nodes
}
915
916/// Traverse a histogram tree to find the leaf for a single binned sample.
917#[inline]
918fn traverse_hist_tree<F: Float>(nodes: &[HistNode<F>], sample_bins: &[u16]) -> usize {
919    let mut idx = 0;
920    loop {
921        match &nodes[idx] {
922            HistNode::Split {
923                feature,
924                threshold_bin,
925                nan_goes_left,
926                left,
927                right,
928                ..
929            } => {
930                let b = sample_bins[*feature];
931                if b == NAN_BIN {
932                    idx = if *nan_goes_left { *left } else { *right };
933                } else if b <= *threshold_bin {
934                    idx = *left;
935                } else {
936                    idx = *right;
937                }
938            }
939            HistNode::Leaf { .. } => return idx,
940        }
941    }
942}
943
944/// Compute feature importances from a histogram tree's gain values.
945fn compute_hist_feature_importances<F: Float>(
946    nodes: &[HistNode<F>],
947    n_features: usize,
948) -> Array1<F> {
949    let mut importances = Array1::zeros(n_features);
950    for node in nodes {
951        if let HistNode::Split { feature, gain, .. } = node {
952            importances[*feature] = importances[*feature] + *gain;
953        }
954    }
955    importances
956}
957
958// ---------------------------------------------------------------------------
959// Internal helpers
960// ---------------------------------------------------------------------------
961
962/// Sigmoid function: 1 / (1 + exp(-x)).
963fn sigmoid<F: Float>(x: F) -> F {
964    F::one() / (F::one() + (-x).exp())
965}
966
967/// Compute softmax probabilities for each class across all samples.
968///
969/// Returns `probs[k][i]` = probability of class k for sample i.
970fn softmax_matrix<F: Float>(
971    f_vals: &[Array1<F>],
972    n_samples: usize,
973    n_classes: usize,
974) -> Vec<Vec<F>> {
975    let mut probs: Vec<Vec<F>> = vec![vec![F::zero(); n_samples]; n_classes];
976    for i in 0..n_samples {
977        let max_val = (0..n_classes)
978            .map(|k| f_vals[k][i])
979            .fold(F::neg_infinity(), |a, b| if b > a { b } else { a });
980        let mut sum = F::zero();
981        let mut exps = vec![F::zero(); n_classes];
982        for k in 0..n_classes {
983            exps[k] = (f_vals[k][i] - max_val).exp();
984            sum = sum + exps[k];
985        }
986        let eps = F::from(1e-15).unwrap();
987        if sum < eps {
988            sum = eps;
989        }
990        for k in 0..n_classes {
991            probs[k][i] = exps[k] / sum;
992        }
993    }
994    probs
995}
996
997// ---------------------------------------------------------------------------
998// HistGradientBoostingRegressor
999// ---------------------------------------------------------------------------
1000
/// Histogram-based gradient boosting regressor.
///
/// Uses quantile-based feature binning and gradient/hessian histograms for
/// O(n_bins) split finding per node. This is significantly faster than the
/// standard [`GradientBoostingRegressor`](crate::GradientBoostingRegressor)
/// for larger datasets.
///
/// Configure via the `with_*` builder methods, then call `fit`.
///
/// # Type Parameters
///
/// - `F`: The floating-point type (`f32` or `f64`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HistGradientBoostingRegressor<F> {
    /// Number of boosting stages (trees).
    pub n_estimators: usize,
    /// Learning rate (shrinkage) applied to each tree's contribution.
    pub learning_rate: f64,
    /// Maximum depth of each tree; `None` imposes no depth limit.
    pub max_depth: Option<usize>,
    /// Minimum number of samples required in a leaf node.
    pub min_samples_leaf: usize,
    /// Maximum number of bins for feature discretisation (at most 256).
    pub max_bins: u16,
    /// L2 regularization term on weights.
    pub l2_regularization: f64,
    /// Maximum number of leaf nodes per tree (best-first growth).
    /// If `None`, depth-first growth is used with `max_depth`.
    pub max_leaf_nodes: Option<usize>,
    /// Loss function.
    pub loss: HistRegressionLoss,
    /// Random seed for reproducibility (not read by `fit` in this version;
    /// presumably reserved for future subsampling — confirm before relying on it).
    pub random_state: Option<u64>,
    // Zero-sized marker tying the struct to the float type `F`.
    _marker: std::marker::PhantomData<F>,
}
1034
1035impl<F: Float> HistGradientBoostingRegressor<F> {
1036    /// Create a new `HistGradientBoostingRegressor` with default settings.
1037    ///
1038    /// Defaults: `n_estimators = 100`, `learning_rate = 0.1`,
1039    /// `max_depth = None`, `min_samples_leaf = 20`,
1040    /// `max_bins = 255`, `l2_regularization = 0.0`,
1041    /// `max_leaf_nodes = Some(31)`, `loss = LeastSquares`.
1042    #[must_use]
1043    pub fn new() -> Self {
1044        Self {
1045            n_estimators: 100,
1046            learning_rate: 0.1,
1047            max_depth: None,
1048            min_samples_leaf: 20,
1049            max_bins: 255,
1050            l2_regularization: 0.0,
1051            max_leaf_nodes: Some(31),
1052            loss: HistRegressionLoss::LeastSquares,
1053            random_state: None,
1054            _marker: std::marker::PhantomData,
1055        }
1056    }
1057
1058    /// Set the number of boosting stages.
1059    #[must_use]
1060    pub fn with_n_estimators(mut self, n: usize) -> Self {
1061        self.n_estimators = n;
1062        self
1063    }
1064
1065    /// Set the learning rate (shrinkage).
1066    #[must_use]
1067    pub fn with_learning_rate(mut self, lr: f64) -> Self {
1068        self.learning_rate = lr;
1069        self
1070    }
1071
1072    /// Set the maximum tree depth.
1073    #[must_use]
1074    pub fn with_max_depth(mut self, d: Option<usize>) -> Self {
1075        self.max_depth = d;
1076        self
1077    }
1078
1079    /// Set the minimum number of samples in a leaf.
1080    #[must_use]
1081    pub fn with_min_samples_leaf(mut self, n: usize) -> Self {
1082        self.min_samples_leaf = n;
1083        self
1084    }
1085
1086    /// Set the maximum number of bins for feature discretisation.
1087    #[must_use]
1088    pub fn with_max_bins(mut self, bins: u16) -> Self {
1089        self.max_bins = bins;
1090        self
1091    }
1092
1093    /// Set the L2 regularization term.
1094    #[must_use]
1095    pub fn with_l2_regularization(mut self, reg: f64) -> Self {
1096        self.l2_regularization = reg;
1097        self
1098    }
1099
1100    /// Set the maximum number of leaf nodes (best-first growth).
1101    #[must_use]
1102    pub fn with_max_leaf_nodes(mut self, n: Option<usize>) -> Self {
1103        self.max_leaf_nodes = n;
1104        self
1105    }
1106
1107    /// Set the loss function.
1108    #[must_use]
1109    pub fn with_loss(mut self, loss: HistRegressionLoss) -> Self {
1110        self.loss = loss;
1111        self
1112    }
1113
1114    /// Set the random seed for reproducibility.
1115    #[must_use]
1116    pub fn with_random_state(mut self, seed: u64) -> Self {
1117        self.random_state = Some(seed);
1118        self
1119    }
1120}
1121
impl<F: Float> Default for HistGradientBoostingRegressor<F> {
    /// Equivalent to [`HistGradientBoostingRegressor::new`].
    fn default() -> Self {
        Self::new()
    }
}
1127
1128// ---------------------------------------------------------------------------
1129// FittedHistGradientBoostingRegressor
1130// ---------------------------------------------------------------------------
1131
/// A fitted histogram-based gradient boosting regressor.
///
/// Stores the binning information, initial prediction, and the sequence of
/// fitted histogram trees. Predictions are computed by binning the input
/// features and traversing each tree.
#[derive(Debug, Clone)]
pub struct FittedHistGradientBoostingRegressor<F> {
    /// Bin edge information for each feature.
    bin_infos: Vec<FeatureBinInfo<F>>,
    /// Initial prediction (baseline added before any tree contribution).
    init: F,
    /// Learning rate used during training; scales every leaf value at
    /// prediction time.
    learning_rate: F,
    /// Sequence of fitted histogram trees, one flat node vector per
    /// boosting stage.
    trees: Vec<Vec<HistNode<F>>>,
    /// Number of features seen during fit (checked at predict time).
    n_features: usize,
    /// Per-feature importance scores (normalised).
    feature_importances: Array1<F>,
}
1152
1153impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<F>>
1154    for HistGradientBoostingRegressor<F>
1155{
1156    type Fitted = FittedHistGradientBoostingRegressor<F>;
1157    type Error = FerroError;
1158
1159    /// Fit the histogram-based gradient boosting regressor.
1160    ///
1161    /// # Errors
1162    ///
1163    /// Returns [`FerroError::ShapeMismatch`] if `x` and `y` have different
1164    /// numbers of samples.
1165    /// Returns [`FerroError::InsufficientSamples`] if there are no samples.
1166    /// Returns [`FerroError::InvalidParameter`] for invalid hyperparameters.
1167    fn fit(
1168        &self,
1169        x: &Array2<F>,
1170        y: &Array1<F>,
1171    ) -> Result<FittedHistGradientBoostingRegressor<F>, FerroError> {
1172        let (n_samples, n_features) = x.dim();
1173
1174        // Validate inputs.
1175        if n_samples != y.len() {
1176            return Err(FerroError::ShapeMismatch {
1177                expected: vec![n_samples],
1178                actual: vec![y.len()],
1179                context: "y length must match number of samples in X".into(),
1180            });
1181        }
1182        if n_samples == 0 {
1183            return Err(FerroError::InsufficientSamples {
1184                required: 1,
1185                actual: 0,
1186                context: "HistGradientBoostingRegressor requires at least one sample".into(),
1187            });
1188        }
1189        if self.n_estimators == 0 {
1190            return Err(FerroError::InvalidParameter {
1191                name: "n_estimators".into(),
1192                reason: "must be at least 1".into(),
1193            });
1194        }
1195        if self.learning_rate <= 0.0 {
1196            return Err(FerroError::InvalidParameter {
1197                name: "learning_rate".into(),
1198                reason: "must be positive".into(),
1199            });
1200        }
1201        if self.max_bins < 2 {
1202            return Err(FerroError::InvalidParameter {
1203                name: "max_bins".into(),
1204                reason: "must be at least 2".into(),
1205            });
1206        }
1207
1208        let lr = F::from(self.learning_rate).unwrap();
1209        let l2_reg = F::from(self.l2_regularization).unwrap();
1210
1211        // Compute bin edges and bin the data.
1212        let bin_infos = compute_bin_edges(x, self.max_bins);
1213        let binned = bin_data(x, &bin_infos);
1214
1215        // Initial prediction.
1216        let init = match self.loss {
1217            HistRegressionLoss::LeastSquares => {
1218                let sum: F = y.iter().copied().fold(F::zero(), |a, b| a + b);
1219                sum / F::from(n_samples).unwrap()
1220            }
1221            HistRegressionLoss::LeastAbsoluteDeviation => median_f(y),
1222        };
1223
1224        let mut f_vals = Array1::from_elem(n_samples, init);
1225        let all_indices: Vec<usize> = (0..n_samples).collect();
1226
1227        let tree_params = HistTreeParams {
1228            max_depth: self.max_depth,
1229            min_samples_leaf: self.min_samples_leaf,
1230            max_leaf_nodes: self.max_leaf_nodes,
1231            l2_regularization: l2_reg,
1232        };
1233
1234        let mut trees = Vec::with_capacity(self.n_estimators);
1235
1236        for _ in 0..self.n_estimators {
1237            // Compute gradients and hessians.
1238            let (gradients, hessians) = match self.loss {
1239                HistRegressionLoss::LeastSquares => {
1240                    let grads: Vec<F> = (0..n_samples).map(|i| -(y[i] - f_vals[i])).collect();
1241                    let hess: Vec<F> = vec![F::one(); n_samples];
1242                    (grads, hess)
1243                }
1244                HistRegressionLoss::LeastAbsoluteDeviation => {
1245                    let grads: Vec<F> = (0..n_samples)
1246                        .map(|i| {
1247                            let diff = y[i] - f_vals[i];
1248                            if diff > F::zero() {
1249                                -F::one()
1250                            } else if diff < F::zero() {
1251                                F::one()
1252                            } else {
1253                                F::zero()
1254                            }
1255                        })
1256                        .collect();
1257                    let hess: Vec<F> = vec![F::one(); n_samples];
1258                    (grads, hess)
1259                }
1260            };
1261
1262            let tree = build_hist_tree(
1263                &binned,
1264                &gradients,
1265                &hessians,
1266                &all_indices,
1267                &bin_infos,
1268                &tree_params,
1269            );
1270
1271            // Update predictions.
1272            for i in 0..n_samples {
1273                let leaf_idx = traverse_hist_tree(&tree, &binned[i]);
1274                if let HistNode::Leaf { value, .. } = tree[leaf_idx] {
1275                    f_vals[i] = f_vals[i] + lr * value;
1276                }
1277            }
1278
1279            trees.push(tree);
1280        }
1281
1282        // Compute feature importances.
1283        let mut total_importances = Array1::<F>::zeros(n_features);
1284        for tree_nodes in &trees {
1285            total_importances =
1286                total_importances + compute_hist_feature_importances(tree_nodes, n_features);
1287        }
1288        let imp_sum: F = total_importances
1289            .iter()
1290            .copied()
1291            .fold(F::zero(), |a, b| a + b);
1292        if imp_sum > F::zero() {
1293            total_importances.mapv_inplace(|v| v / imp_sum);
1294        }
1295
1296        Ok(FittedHistGradientBoostingRegressor {
1297            bin_infos,
1298            init,
1299            learning_rate: lr,
1300            trees,
1301            n_features,
1302            feature_importances: total_importances,
1303        })
1304    }
1305}
1306
1307impl<F: Float + Send + Sync + 'static> Predict<Array2<F>>
1308    for FittedHistGradientBoostingRegressor<F>
1309{
1310    type Output = Array1<F>;
1311    type Error = FerroError;
1312
1313    /// Predict target values.
1314    ///
1315    /// # Errors
1316    ///
1317    /// Returns [`FerroError::ShapeMismatch`] if the number of features does
1318    /// not match the fitted model.
1319    fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
1320        if x.ncols() != self.n_features {
1321            return Err(FerroError::ShapeMismatch {
1322                expected: vec![self.n_features],
1323                actual: vec![x.ncols()],
1324                context: "number of features must match fitted model".into(),
1325            });
1326        }
1327
1328        let n_samples = x.nrows();
1329        let binned = bin_data(x, &self.bin_infos);
1330        let mut predictions = Array1::from_elem(n_samples, self.init);
1331
1332        for i in 0..n_samples {
1333            for tree_nodes in &self.trees {
1334                let leaf_idx = traverse_hist_tree(tree_nodes, &binned[i]);
1335                if let HistNode::Leaf { value, .. } = tree_nodes[leaf_idx] {
1336                    predictions[i] = predictions[i] + self.learning_rate * value;
1337                }
1338            }
1339        }
1340
1341        Ok(predictions)
1342    }
1343}
1344
impl<F: Float + Send + Sync + 'static> HasFeatureImportances<F>
    for FittedHistGradientBoostingRegressor<F>
{
    /// Gain-based per-feature importances computed during `fit`
    /// (normalised to sum to one when any split occurred).
    fn feature_importances(&self) -> &Array1<F> {
        &self.feature_importances
    }
}
1352
1353// Pipeline integration.
1354impl<F: Float + Send + Sync + 'static> PipelineEstimator<F> for HistGradientBoostingRegressor<F> {
1355    fn fit_pipeline(
1356        &self,
1357        x: &Array2<F>,
1358        y: &Array1<F>,
1359    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
1360        let fitted = self.fit(x, y)?;
1361        Ok(Box::new(fitted))
1362    }
1363}
1364
impl<F: Float + Send + Sync + 'static> FittedPipelineEstimator<F>
    for FittedHistGradientBoostingRegressor<F>
{
    /// Delegates directly to [`Predict::predict`].
    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
        self.predict(x)
    }
}
1372
1373// ---------------------------------------------------------------------------
1374// HistGradientBoostingClassifier
1375// ---------------------------------------------------------------------------
1376
/// Histogram-based gradient boosting classifier.
///
/// For binary classification a single model is trained on log-odds residuals.
/// For multiclass (*K* classes), *K* histogram trees are built per boosting
/// round (one-vs-rest in probability space via softmax).
///
/// Configure via the `with_*` builder methods, then call `fit`.
///
/// # Type Parameters
///
/// - `F`: The floating-point type (`f32` or `f64`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HistGradientBoostingClassifier<F> {
    /// Number of boosting stages.
    pub n_estimators: usize,
    /// Learning rate (shrinkage).
    pub learning_rate: f64,
    /// Maximum depth of each tree; `None` imposes no depth limit.
    pub max_depth: Option<usize>,
    /// Minimum number of samples required in a leaf node.
    pub min_samples_leaf: usize,
    /// Maximum number of bins for feature discretisation (at most 256).
    pub max_bins: u16,
    /// L2 regularization term on weights.
    pub l2_regularization: f64,
    /// Maximum number of leaf nodes per tree (best-first growth).
    pub max_leaf_nodes: Option<usize>,
    /// Classification loss function.
    pub loss: HistClassificationLoss,
    /// Random seed for reproducibility (reserved for future subsampling).
    pub random_state: Option<u64>,
    // Zero-sized marker tying the struct to the float type `F`.
    _marker: std::marker::PhantomData<F>,
}
1408
1409impl<F: Float> HistGradientBoostingClassifier<F> {
1410    /// Create a new `HistGradientBoostingClassifier` with default settings.
1411    ///
1412    /// Defaults: `n_estimators = 100`, `learning_rate = 0.1`,
1413    /// `max_depth = None`, `min_samples_leaf = 20`,
1414    /// `max_bins = 255`, `l2_regularization = 0.0`,
1415    /// `max_leaf_nodes = Some(31)`, `loss = LogLoss`.
1416    #[must_use]
1417    pub fn new() -> Self {
1418        Self {
1419            n_estimators: 100,
1420            learning_rate: 0.1,
1421            max_depth: None,
1422            min_samples_leaf: 20,
1423            max_bins: 255,
1424            l2_regularization: 0.0,
1425            max_leaf_nodes: Some(31),
1426            loss: HistClassificationLoss::LogLoss,
1427            random_state: None,
1428            _marker: std::marker::PhantomData,
1429        }
1430    }
1431
1432    /// Set the number of boosting stages.
1433    #[must_use]
1434    pub fn with_n_estimators(mut self, n: usize) -> Self {
1435        self.n_estimators = n;
1436        self
1437    }
1438
1439    /// Set the learning rate (shrinkage).
1440    #[must_use]
1441    pub fn with_learning_rate(mut self, lr: f64) -> Self {
1442        self.learning_rate = lr;
1443        self
1444    }
1445
1446    /// Set the maximum tree depth.
1447    #[must_use]
1448    pub fn with_max_depth(mut self, d: Option<usize>) -> Self {
1449        self.max_depth = d;
1450        self
1451    }
1452
1453    /// Set the minimum number of samples in a leaf.
1454    #[must_use]
1455    pub fn with_min_samples_leaf(mut self, n: usize) -> Self {
1456        self.min_samples_leaf = n;
1457        self
1458    }
1459
1460    /// Set the maximum number of bins for feature discretisation.
1461    #[must_use]
1462    pub fn with_max_bins(mut self, bins: u16) -> Self {
1463        self.max_bins = bins;
1464        self
1465    }
1466
1467    /// Set the L2 regularization term.
1468    #[must_use]
1469    pub fn with_l2_regularization(mut self, reg: f64) -> Self {
1470        self.l2_regularization = reg;
1471        self
1472    }
1473
1474    /// Set the maximum number of leaf nodes (best-first growth).
1475    #[must_use]
1476    pub fn with_max_leaf_nodes(mut self, n: Option<usize>) -> Self {
1477        self.max_leaf_nodes = n;
1478        self
1479    }
1480
1481    /// Set the random seed for reproducibility.
1482    #[must_use]
1483    pub fn with_random_state(mut self, seed: u64) -> Self {
1484        self.random_state = Some(seed);
1485        self
1486    }
1487}
1488
impl<F: Float> Default for HistGradientBoostingClassifier<F> {
    /// Equivalent to [`HistGradientBoostingClassifier::new`].
    fn default() -> Self {
        Self::new()
    }
}
1494
1495// ---------------------------------------------------------------------------
1496// FittedHistGradientBoostingClassifier
1497// ---------------------------------------------------------------------------
1498
/// A fitted histogram-based gradient boosting classifier.
///
/// For binary classification, stores a single sequence of trees predicting log-odds.
/// For multiclass, stores `K` sequences of trees (one per class).
#[derive(Debug, Clone)]
pub struct FittedHistGradientBoostingClassifier<F> {
    /// Bin edge information for each feature.
    bin_infos: Vec<FeatureBinInfo<F>>,
    /// Sorted unique class labels seen during fit.
    classes: Vec<usize>,
    /// Initial predictions per class (log-odds for binary, log-prior per
    /// class for multiclass).
    init: Vec<F>,
    /// Learning rate; scales every leaf value at prediction time.
    learning_rate: F,
    /// Trees: for binary, `trees[0]` has all trees. For multiclass,
    /// `trees[k]` has trees for class k.
    trees: Vec<Vec<Vec<HistNode<F>>>>,
    /// Number of features seen during fit.
    n_features: usize,
    /// Per-feature importance scores (normalised).
    feature_importances: Array1<F>,
}
1521
1522impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<usize>>
1523    for HistGradientBoostingClassifier<F>
1524{
1525    type Fitted = FittedHistGradientBoostingClassifier<F>;
1526    type Error = FerroError;
1527
1528    /// Fit the histogram-based gradient boosting classifier.
1529    ///
1530    /// # Errors
1531    ///
1532    /// Returns [`FerroError::ShapeMismatch`] if `x` and `y` have different
1533    /// numbers of samples.
1534    /// Returns [`FerroError::InsufficientSamples`] if there are no samples.
1535    /// Returns [`FerroError::InvalidParameter`] for invalid hyperparameters.
1536    fn fit(
1537        &self,
1538        x: &Array2<F>,
1539        y: &Array1<usize>,
1540    ) -> Result<FittedHistGradientBoostingClassifier<F>, FerroError> {
1541        let (n_samples, n_features) = x.dim();
1542
1543        if n_samples != y.len() {
1544            return Err(FerroError::ShapeMismatch {
1545                expected: vec![n_samples],
1546                actual: vec![y.len()],
1547                context: "y length must match number of samples in X".into(),
1548            });
1549        }
1550        if n_samples == 0 {
1551            return Err(FerroError::InsufficientSamples {
1552                required: 1,
1553                actual: 0,
1554                context: "HistGradientBoostingClassifier requires at least one sample".into(),
1555            });
1556        }
1557        if self.n_estimators == 0 {
1558            return Err(FerroError::InvalidParameter {
1559                name: "n_estimators".into(),
1560                reason: "must be at least 1".into(),
1561            });
1562        }
1563        if self.learning_rate <= 0.0 {
1564            return Err(FerroError::InvalidParameter {
1565                name: "learning_rate".into(),
1566                reason: "must be positive".into(),
1567            });
1568        }
1569        if self.max_bins < 2 {
1570            return Err(FerroError::InvalidParameter {
1571                name: "max_bins".into(),
1572                reason: "must be at least 2".into(),
1573            });
1574        }
1575
1576        // Determine unique classes.
1577        let mut classes: Vec<usize> = y.iter().copied().collect();
1578        classes.sort_unstable();
1579        classes.dedup();
1580        let n_classes = classes.len();
1581
1582        if n_classes < 2 {
1583            return Err(FerroError::InvalidParameter {
1584                name: "y".into(),
1585                reason: "need at least 2 distinct classes".into(),
1586            });
1587        }
1588
1589        let y_mapped: Vec<usize> = y
1590            .iter()
1591            .map(|&c| classes.iter().position(|&cl| cl == c).unwrap())
1592            .collect();
1593
1594        let lr = F::from(self.learning_rate).unwrap();
1595        let l2_reg = F::from(self.l2_regularization).unwrap();
1596
1597        // Bin the data.
1598        let bin_infos = compute_bin_edges(x, self.max_bins);
1599        let binned = bin_data(x, &bin_infos);
1600
1601        let tree_params = HistTreeParams {
1602            max_depth: self.max_depth,
1603            min_samples_leaf: self.min_samples_leaf,
1604            max_leaf_nodes: self.max_leaf_nodes,
1605            l2_regularization: l2_reg,
1606        };
1607
1608        let all_indices: Vec<usize> = (0..n_samples).collect();
1609
1610        if n_classes == 2 {
1611            self.fit_binary(
1612                &binned,
1613                &y_mapped,
1614                n_samples,
1615                n_features,
1616                &classes,
1617                lr,
1618                &tree_params,
1619                &all_indices,
1620                &bin_infos,
1621            )
1622        } else {
1623            self.fit_multiclass(
1624                &binned,
1625                &y_mapped,
1626                n_samples,
1627                n_features,
1628                n_classes,
1629                &classes,
1630                lr,
1631                &tree_params,
1632                &all_indices,
1633                &bin_infos,
1634            )
1635        }
1636    }
1637}
1638
1639impl<F: Float + Send + Sync + 'static> HistGradientBoostingClassifier<F> {
1640    /// Fit binary classification (log-loss on log-odds).
1641    #[allow(clippy::too_many_arguments)]
1642    fn fit_binary(
1643        &self,
1644        binned: &[Vec<u16>],
1645        y_mapped: &[usize],
1646        n_samples: usize,
1647        n_features: usize,
1648        classes: &[usize],
1649        lr: F,
1650        tree_params: &HistTreeParams<F>,
1651        all_indices: &[usize],
1652        bin_infos: &[FeatureBinInfo<F>],
1653    ) -> Result<FittedHistGradientBoostingClassifier<F>, FerroError> {
1654        // Initial log-odds.
1655        let pos_count = y_mapped.iter().filter(|&&c| c == 1).count();
1656        let p = F::from(pos_count).unwrap() / F::from(n_samples).unwrap();
1657        let eps = F::from(1e-15).unwrap();
1658        let p_clipped = p.max(eps).min(F::one() - eps);
1659        let init_val = (p_clipped / (F::one() - p_clipped)).ln();
1660
1661        let mut f_vals = Array1::from_elem(n_samples, init_val);
1662        let mut trees_seq: Vec<Vec<HistNode<F>>> = Vec::with_capacity(self.n_estimators);
1663
1664        for _ in 0..self.n_estimators {
1665            // Compute probabilities.
1666            let probs: Vec<F> = f_vals.iter().map(|&fv| sigmoid(fv)).collect();
1667
1668            // Gradients and hessians for log-loss:
1669            //   gradient = p - y (we negate pseudo-residual because tree fits -gradient)
1670            //   hessian = p * (1 - p)
1671            let gradients: Vec<F> = (0..n_samples)
1672                .map(|i| {
1673                    let yi = F::from(y_mapped[i]).unwrap();
1674                    probs[i] - yi
1675                })
1676                .collect();
1677            let hessians: Vec<F> = (0..n_samples)
1678                .map(|i| {
1679                    let pi = probs[i].max(eps).min(F::one() - eps);
1680                    pi * (F::one() - pi)
1681                })
1682                .collect();
1683
1684            let tree = build_hist_tree(
1685                binned,
1686                &gradients,
1687                &hessians,
1688                all_indices,
1689                bin_infos,
1690                tree_params,
1691            );
1692
1693            // Update f_vals.
1694            for i in 0..n_samples {
1695                let leaf_idx = traverse_hist_tree(&tree, &binned[i]);
1696                if let HistNode::Leaf { value, .. } = tree[leaf_idx] {
1697                    f_vals[i] = f_vals[i] + lr * value;
1698                }
1699            }
1700
1701            trees_seq.push(tree);
1702        }
1703
1704        // Feature importances.
1705        let mut total_importances = Array1::<F>::zeros(n_features);
1706        for tree_nodes in &trees_seq {
1707            total_importances =
1708                total_importances + compute_hist_feature_importances(tree_nodes, n_features);
1709        }
1710        let imp_sum: F = total_importances
1711            .iter()
1712            .copied()
1713            .fold(F::zero(), |a, b| a + b);
1714        if imp_sum > F::zero() {
1715            total_importances.mapv_inplace(|v| v / imp_sum);
1716        }
1717
1718        Ok(FittedHistGradientBoostingClassifier {
1719            bin_infos: bin_infos.to_vec(),
1720            classes: classes.to_vec(),
1721            init: vec![init_val],
1722            learning_rate: lr,
1723            trees: vec![trees_seq],
1724            n_features,
1725            feature_importances: total_importances,
1726        })
1727    }
1728
    /// Fit multiclass classification (K trees per round, softmax).
    ///
    /// Each boosting round trains one tree per class. Class-`k` gradients are
    /// `p_k - y_k` and hessians `p_k * (1 - p_k)` under softmax cross-entropy,
    /// where the probabilities `p` are computed once per round from the raw
    /// scores `f_vals` *before* any of the round's trees are added.
    ///
    /// `binned` holds one pre-binned feature row (`Vec<u16>`) per sample,
    /// `y_mapped` holds class indices in `0..n_classes`, and `classes` maps
    /// those indices back to the original labels stored in the fitted model.
    #[allow(clippy::too_many_arguments)]
    fn fit_multiclass(
        &self,
        binned: &[Vec<u16>],
        y_mapped: &[usize],
        n_samples: usize,
        n_features: usize,
        n_classes: usize,
        classes: &[usize],
        lr: F,
        tree_params: &HistTreeParams<F>,
        all_indices: &[usize],
        bin_infos: &[FeatureBinInfo<F>],
    ) -> Result<FittedHistGradientBoostingClassifier<F>, FerroError> {
        // Initial log-prior for each class.
        let mut class_counts = vec![0usize; n_classes];
        for &c in y_mapped {
            class_counts[c] += 1;
        }
        let n_f = F::from(n_samples).unwrap();
        // Floor probabilities at eps so ln() never sees zero (e.g. a class
        // with zero observed samples).
        let eps = F::from(1e-15).unwrap();
        let init_vals: Vec<F> = class_counts
            .iter()
            .map(|&cnt| {
                let p = (F::from(cnt).unwrap() / n_f).max(eps);
                p.ln()
            })
            .collect();

        // Raw (pre-softmax) scores, one array per class, seeded with the prior.
        let mut f_vals: Vec<Array1<F>> = init_vals
            .iter()
            .map(|&init| Array1::from_elem(n_samples, init))
            .collect();

        let mut trees_per_class: Vec<Vec<Vec<HistNode<F>>>> = (0..n_classes)
            .map(|_| Vec::with_capacity(self.n_estimators))
            .collect();

        for _ in 0..self.n_estimators {
            // Softmax snapshot taken once per round: every class's tree in
            // this round is fit against the same probabilities.
            let probs = softmax_matrix(&f_vals, n_samples, n_classes);

            for k in 0..n_classes {
                // Gradients and hessians for softmax cross-entropy:
                //   gradient_k = p_k - y_k
                //   hessian_k  = p_k * (1 - p_k)
                let gradients: Vec<F> = (0..n_samples)
                    .map(|i| {
                        let yi_k = if y_mapped[i] == k {
                            F::one()
                        } else {
                            F::zero()
                        };
                        probs[k][i] - yi_k
                    })
                    .collect();
                let hessians: Vec<F> = (0..n_samples)
                    .map(|i| {
                        // Clamp p away from 0 and 1 so the hessian stays
                        // strictly positive.
                        let pk = probs[k][i].max(eps).min(F::one() - eps);
                        pk * (F::one() - pk)
                    })
                    .collect();

                let tree = build_hist_tree(
                    binned,
                    &gradients,
                    &hessians,
                    all_indices,
                    bin_infos,
                    tree_params,
                );

                // Update f_vals for class k by routing each sample to its leaf.
                for (i, fv) in f_vals[k].iter_mut().enumerate() {
                    let leaf_idx = traverse_hist_tree(&tree, &binned[i]);
                    if let HistNode::Leaf { value, .. } = tree[leaf_idx] {
                        *fv = *fv + lr * value;
                    }
                }

                trees_per_class[k].push(tree);
            }
        }

        // Feature importances aggregated across all classes and rounds.
        let mut total_importances = Array1::<F>::zeros(n_features);
        for class_trees in &trees_per_class {
            for tree_nodes in class_trees {
                total_importances =
                    total_importances + compute_hist_feature_importances(tree_nodes, n_features);
            }
        }
        let imp_sum: F = total_importances
            .iter()
            .copied()
            .fold(F::zero(), |a, b| a + b);
        // Normalise to sum to 1, unless every importance is zero.
        if imp_sum > F::zero() {
            total_importances.mapv_inplace(|v| v / imp_sum);
        }

        Ok(FittedHistGradientBoostingClassifier {
            bin_infos: bin_infos.to_vec(),
            classes: classes.to_vec(),
            init: init_vals,
            learning_rate: lr,
            trees: trees_per_class,
            n_features,
            feature_importances: total_importances,
        })
    }
1839}
1840
1841impl<F: Float + Send + Sync + 'static> Predict<Array2<F>>
1842    for FittedHistGradientBoostingClassifier<F>
1843{
1844    type Output = Array1<usize>;
1845    type Error = FerroError;
1846
1847    /// Predict class labels.
1848    ///
1849    /// # Errors
1850    ///
1851    /// Returns [`FerroError::ShapeMismatch`] if the number of features does
1852    /// not match the fitted model.
1853    fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
1854        if x.ncols() != self.n_features {
1855            return Err(FerroError::ShapeMismatch {
1856                expected: vec![self.n_features],
1857                actual: vec![x.ncols()],
1858                context: "number of features must match fitted model".into(),
1859            });
1860        }
1861
1862        let n_samples = x.nrows();
1863        let n_classes = self.classes.len();
1864        let binned = bin_data(x, &self.bin_infos);
1865
1866        if n_classes == 2 {
1867            let init = self.init[0];
1868            let mut predictions = Array1::zeros(n_samples);
1869            for i in 0..n_samples {
1870                let mut f_val = init;
1871                for tree_nodes in &self.trees[0] {
1872                    let leaf_idx = traverse_hist_tree(tree_nodes, &binned[i]);
1873                    if let HistNode::Leaf { value, .. } = tree_nodes[leaf_idx] {
1874                        f_val = f_val + self.learning_rate * value;
1875                    }
1876                }
1877                let prob = sigmoid(f_val);
1878                let class_idx = if prob >= F::from(0.5).unwrap() { 1 } else { 0 };
1879                predictions[i] = self.classes[class_idx];
1880            }
1881            Ok(predictions)
1882        } else {
1883            let mut predictions = Array1::zeros(n_samples);
1884            for i in 0..n_samples {
1885                let mut scores = Vec::with_capacity(n_classes);
1886                for k in 0..n_classes {
1887                    let mut f_val = self.init[k];
1888                    for tree_nodes in &self.trees[k] {
1889                        let leaf_idx = traverse_hist_tree(tree_nodes, &binned[i]);
1890                        if let HistNode::Leaf { value, .. } = tree_nodes[leaf_idx] {
1891                            f_val = f_val + self.learning_rate * value;
1892                        }
1893                    }
1894                    scores.push(f_val);
1895                }
1896                let best_k = scores
1897                    .iter()
1898                    .enumerate()
1899                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
1900                    .map(|(k, _)| k)
1901                    .unwrap_or(0);
1902                predictions[i] = self.classes[best_k];
1903            }
1904            Ok(predictions)
1905        }
1906    }
1907}
1908
impl<F: Float + Send + Sync + 'static> HasFeatureImportances<F>
    for FittedHistGradientBoostingClassifier<F>
{
    /// Per-feature importances aggregated over all trees at fit time
    /// (normalised to sum to 1 unless every importance is zero).
    fn feature_importances(&self) -> &Array1<F> {
        &self.feature_importances
    }
}
1916
impl<F: Float + Send + Sync + 'static> HasClasses for FittedHistGradientBoostingClassifier<F> {
    /// Distinct class labels recorded during `fit`.
    fn classes(&self) -> &[usize] {
        &self.classes
    }

    /// Number of distinct classes.
    fn n_classes(&self) -> usize {
        self.classes.len()
    }
}
1926
1927// Pipeline integration.
impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> PipelineEstimator<F>
    for HistGradientBoostingClassifier<F>
{
    /// Fit within a pipeline: `F`-valued labels are converted to `usize`
    /// before delegating to the regular `fit`.
    fn fit_pipeline(
        &self,
        x: &Array2<F>,
        y: &Array1<F>,
    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
        // NOTE(review): labels that cannot be represented as usize (negative,
        // NaN, infinite) silently collapse to class 0 via `unwrap_or(0)` —
        // confirm this is intended; surfacing an error may be preferable.
        let y_usize: Array1<usize> = y.mapv(|v| v.to_usize().unwrap_or(0));
        let fitted = self.fit(x, &y_usize)?;
        Ok(Box::new(FittedHgbcPipelineAdapter(fitted)))
    }
}
1941
/// Pipeline adapter for `FittedHistGradientBoostingClassifier<F>`.
///
/// Newtype wrapper so the fitted classifier (whose `predict` yields
/// `Array1<usize>`) can satisfy the `F`-valued pipeline prediction trait.
struct FittedHgbcPipelineAdapter<F: Float + Send + Sync + 'static>(
    FittedHistGradientBoostingClassifier<F>,
);
1946
1947impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> FittedPipelineEstimator<F>
1948    for FittedHgbcPipelineAdapter<F>
1949{
1950    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
1951        let preds = self.0.predict(x)?;
1952        Ok(preds.mapv(|v| F::from_usize(v).unwrap_or(F::nan())))
1953    }
1954}
1955
1956// ---------------------------------------------------------------------------
1957// Internal helpers
1958// ---------------------------------------------------------------------------
1959
1960/// Compute the median of an `Array1`.
1961fn median_f<F: Float>(arr: &Array1<F>) -> F {
1962    let mut sorted: Vec<F> = arr.iter().copied().collect();
1963    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
1964    let n = sorted.len();
1965    if n == 0 {
1966        return F::zero();
1967    }
1968    if n % 2 == 1 {
1969        sorted[n / 2]
1970    } else {
1971        (sorted[n / 2 - 1] + sorted[n / 2]) / F::from(2.0).unwrap()
1972    }
1973}
1974
1975// ---------------------------------------------------------------------------
1976// Tests
1977// ---------------------------------------------------------------------------
1978
1979#[cfg(test)]
1980mod tests {
1981    use super::*;
1982    use approx::assert_relative_eq;
1983    use ndarray::array;
1984
1985    // -- Binning tests --
1986
    #[test]
    fn test_bin_edges_simple() {
        // 8 distinct finite values with max_bins = 4: expect 2..=4 bins and
        // no dedicated NaN bin.
        let x =
            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let infos = compute_bin_edges(&x, 4);
        assert_eq!(infos.len(), 1);
        // Should have up to 4 bins.
        assert!(infos[0].n_bins <= 4);
        assert!(infos[0].n_bins >= 2);
        assert!(!infos[0].has_nan);
    }
1998
    #[test]
    fn test_bin_edges_with_nan() {
        // Mixed NaN/finite column: NaNs must be flagged and the finite
        // values must still produce at least one bin.
        let x = Array2::from_shape_vec((5, 1), vec![1.0, f64::NAN, 3.0, f64::NAN, 5.0]).unwrap();
        let infos = compute_bin_edges(&x, 4);
        assert_eq!(infos.len(), 1);
        assert!(infos[0].has_nan);
        assert!(infos[0].n_bins >= 1);
    }
2007
    #[test]
    fn test_bin_edges_all_nan() {
        // A column of only NaN gets zero regular bins plus the NaN flag.
        let x = Array2::from_shape_vec((3, 1), vec![f64::NAN, f64::NAN, f64::NAN]).unwrap();
        let infos = compute_bin_edges(&x, 4);
        assert_eq!(infos[0].n_bins, 0);
        assert!(infos[0].has_nan);
    }
2015
    #[test]
    fn test_map_to_bin_basic() {
        // Edges [2, 4, 6, 8]: values at or below an edge map to that bin,
        // values beyond the last edge clamp to the last bin, NaN maps to
        // the sentinel NAN_BIN.
        let info = FeatureBinInfo {
            edges: vec![2.0, 4.0, 6.0, 8.0],
            n_bins: 4,
            has_nan: false,
        };
        assert_eq!(map_to_bin(1.0, &info), 0);
        assert_eq!(map_to_bin(2.0, &info), 0);
        assert_eq!(map_to_bin(3.0, &info), 1);
        assert_eq!(map_to_bin(5.0, &info), 2);
        assert_eq!(map_to_bin(7.0, &info), 3);
        assert_eq!(map_to_bin(9.0, &info), 3);
        assert_eq!(map_to_bin(f64::NAN, &info), NAN_BIN);
    }
2031
    #[test]
    fn test_bin_data_roundtrip() {
        // Binning must preserve per-column ordering of sorted input.
        let x = Array2::from_shape_vec((4, 2), vec![1.0, 10.0, 2.0, 20.0, 3.0, 30.0, 4.0, 40.0])
            .unwrap();
        let infos = compute_bin_edges(&x, 255);
        let binned = bin_data(&x, &infos);
        assert_eq!(binned.len(), 4);
        assert_eq!(binned[0].len(), 2);
        // Values should be monotonically non-decreasing since input is sorted.
        for j in 0..2 {
            for i in 1..4 {
                assert!(binned[i][j] >= binned[i - 1][j]);
            }
        }
    }
2047
2048    // -- Subtraction trick test --
2049
    #[test]
    fn test_subtraction_trick() {
        // parent_hist - left_hist must equal the directly-built right_hist
        // (grad sums, hess sums, and counts) for every bin of every feature.
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let bin_infos = compute_bin_edges(&x, 255);
        let binned = bin_data(&x, &bin_infos);

        let gradients = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6];
        let hessians = vec![1.0; 6];
        let all_indices: Vec<usize> = (0..6).collect();
        let left_indices: Vec<usize> = vec![0, 1, 2];
        let right_indices: Vec<usize> = vec![3, 4, 5];

        let parent_hist =
            build_histograms(&binned, &gradients, &hessians, &all_indices, &bin_infos);
        let left_hist = build_histograms(&binned, &gradients, &hessians, &left_indices, &bin_infos);
        let right_from_sub = subtract_histograms(&parent_hist, &left_hist, &bin_infos);
        let right_direct =
            build_histograms(&binned, &gradients, &hessians, &right_indices, &bin_infos);

        // The subtraction trick result should match direct computation.
        for j in 0..bin_infos.len() {
            let n_bins = bin_infos[j].n_bins as usize;
            for b in 0..n_bins {
                assert_relative_eq!(
                    right_from_sub[j].bins[b].grad_sum,
                    right_direct[j].bins[b].grad_sum,
                    epsilon = 1e-10
                );
                assert_relative_eq!(
                    right_from_sub[j].bins[b].hess_sum,
                    right_direct[j].bins[b].hess_sum,
                    epsilon = 1e-10
                );
                assert_eq!(
                    right_from_sub[j].bins[b].count,
                    right_direct[j].bins[b].count
                );
            }
        }
    }
2090
2091    // -- Regressor tests --
2092
    #[test]
    fn test_hgbr_simple_least_squares() {
        // Step-function target: the regressor should separate the two
        // plateaus (loosely, predictions on either side of 3.0).
        let x =
            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0];

        let model = HistGradientBoostingRegressor::<f64>::new()
            .with_n_estimators(50)
            .with_learning_rate(0.1)
            .with_min_samples_leaf(1)
            .with_max_leaf_nodes(None)
            .with_max_depth(Some(3))
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 8);
        for i in 0..4 {
            assert!(preds[i] < 3.0, "Expected ~1.0, got {}", preds[i]);
        }
        for i in 4..8 {
            assert!(preds[i] > 3.0, "Expected ~5.0, got {}", preds[i]);
        }
    }
2117
    #[test]
    fn test_hgbr_lad_loss() {
        // Same step-function target under LAD loss; bounds are looser than
        // the least-squares test since LAD converges more slowly.
        let x =
            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0];

        let model = HistGradientBoostingRegressor::<f64>::new()
            .with_n_estimators(50)
            .with_loss(HistRegressionLoss::LeastAbsoluteDeviation)
            .with_min_samples_leaf(1)
            .with_max_leaf_nodes(None)
            .with_max_depth(Some(3))
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 8);
        for i in 0..4 {
            assert!(preds[i] < 3.5, "LAD expected <3.5, got {}", preds[i]);
        }
        for i in 4..8 {
            assert!(preds[i] > 2.5, "LAD expected >2.5, got {}", preds[i]);
        }
    }
2142
    #[test]
    fn test_hgbr_reproducibility() {
        // Two fits with the same random_state must produce identical
        // predictions.
        let x =
            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0];

        let model = HistGradientBoostingRegressor::<f64>::new()
            .with_n_estimators(20)
            .with_min_samples_leaf(1)
            .with_max_leaf_nodes(None)
            .with_max_depth(Some(3))
            .with_random_state(123);

        let fitted1 = model.fit(&x, &y).unwrap();
        let fitted2 = model.fit(&x, &y).unwrap();

        let preds1 = fitted1.predict(&x).unwrap();
        let preds2 = fitted2.predict(&x).unwrap();

        for (p1, p2) in preds1.iter().zip(preds2.iter()) {
            assert_relative_eq!(*p1, *p2, epsilon = 1e-10);
        }
    }
2166
    #[test]
    fn test_hgbr_feature_importances() {
        // Only feature 0 varies; it must dominate the importances.
        let x = Array2::from_shape_vec(
            (10, 3),
            vec![
                1.0, 0.0, 0.0, 2.0, 0.0, 0.0, 3.0, 0.0, 0.0, 4.0, 0.0, 0.0, 5.0, 0.0, 0.0, 6.0,
                0.0, 0.0, 7.0, 0.0, 0.0, 8.0, 0.0, 0.0, 9.0, 0.0, 0.0, 10.0, 0.0, 0.0,
            ],
        )
        .unwrap();
        let y = array![1.0, 1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0, 5.0];

        let model = HistGradientBoostingRegressor::<f64>::new()
            .with_n_estimators(20)
            .with_min_samples_leaf(1)
            .with_max_leaf_nodes(None)
            .with_max_depth(Some(3))
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let importances = fitted.feature_importances();

        assert_eq!(importances.len(), 3);
        // First feature should be most important since it's the only one with variance.
        assert!(
            importances[0] > importances[1],
            "Expected imp[0]={} > imp[1]={}",
            importances[0],
            importances[1]
        );
    }
2197
    #[test]
    fn test_hgbr_shape_mismatch_fit() {
        // y shorter than x's row count must be rejected at fit time.
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![1.0, 2.0];

        let model = HistGradientBoostingRegressor::<f64>::new().with_n_estimators(5);
        assert!(model.fit(&x, &y).is_err());
    }
2206
    #[test]
    fn test_hgbr_shape_mismatch_predict() {
        // Predicting with a different feature count than fit must error.
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let model = HistGradientBoostingRegressor::<f64>::new()
            .with_n_estimators(5)
            .with_min_samples_leaf(1)
            .with_max_leaf_nodes(None)
            .with_max_depth(Some(3))
            .with_random_state(0);
        let fitted = model.fit(&x, &y).unwrap();

        let x_bad = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        assert!(fitted.predict(&x_bad).is_err());
    }
2224
    #[test]
    fn test_hgbr_empty_data() {
        // Zero-sample input must be rejected rather than panic.
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<f64>::zeros(0);

        let model = HistGradientBoostingRegressor::<f64>::new().with_n_estimators(5);
        assert!(model.fit(&x, &y).is_err());
    }
2233
    #[test]
    fn test_hgbr_zero_estimators() {
        // n_estimators == 0 is an invalid hyperparameter.
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let model = HistGradientBoostingRegressor::<f64>::new().with_n_estimators(0);
        assert!(model.fit(&x, &y).is_err());
    }
2242
    #[test]
    fn test_hgbr_invalid_learning_rate() {
        // learning_rate == 0 is an invalid hyperparameter.
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let model = HistGradientBoostingRegressor::<f64>::new()
            .with_n_estimators(5)
            .with_learning_rate(0.0);
        assert!(model.fit(&x, &y).is_err());
    }
2253
    #[test]
    fn test_hgbr_invalid_max_bins() {
        // max_bins == 1 cannot represent any split and must be rejected.
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let model = HistGradientBoostingRegressor::<f64>::new()
            .with_n_estimators(5)
            .with_max_bins(1);
        assert!(model.fit(&x, &y).is_err());
    }
2264
    #[test]
    fn test_hgbr_default_trait() {
        // Defaults: 100 estimators, learning rate 0.1, 255 bins.
        let model = HistGradientBoostingRegressor::<f64>::default();
        assert_eq!(model.n_estimators, 100);
        assert!((model.learning_rate - 0.1).abs() < 1e-10);
        assert_eq!(model.max_bins, 255);
    }
2272
    #[test]
    fn test_hgbr_pipeline_integration() {
        // fit_pipeline/predict_pipeline round-trip produces one value per row.
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let model = HistGradientBoostingRegressor::<f64>::new()
            .with_n_estimators(10)
            .with_min_samples_leaf(1)
            .with_max_leaf_nodes(None)
            .with_max_depth(Some(3))
            .with_random_state(42);
        let fitted = model.fit_pipeline(&x, &y).unwrap();
        let preds = fitted.predict_pipeline(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }
2288
    #[test]
    fn test_hgbr_f32_support() {
        // The regressor is generic over the float type; f32 must work too.
        let x = Array2::from_shape_vec((4, 1), vec![1.0f32, 2.0, 3.0, 4.0]).unwrap();
        let y = Array1::from_vec(vec![1.0f32, 2.0, 3.0, 4.0]);

        let model = HistGradientBoostingRegressor::<f32>::new()
            .with_n_estimators(10)
            .with_min_samples_leaf(1)
            .with_max_leaf_nodes(None)
            .with_max_depth(Some(3))
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }
2304
    #[test]
    fn test_hgbr_nan_handling() {
        // Some features have NaN — the model should handle them gracefully.
        let x = Array2::from_shape_vec(
            (8, 1),
            vec![1.0, f64::NAN, 3.0, 4.0, 5.0, f64::NAN, 7.0, 8.0],
        )
        .unwrap();
        let y = array![1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0];

        let model = HistGradientBoostingRegressor::<f64>::new()
            .with_n_estimators(20)
            .with_min_samples_leaf(1)
            .with_max_leaf_nodes(None)
            .with_max_depth(Some(3))
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 8);
        // Should still produce finite predictions for all samples (including NaN inputs).
        for p in preds.iter() {
            assert!(p.is_finite(), "Expected finite prediction, got {}", p);
        }
    }
2329
    #[test]
    fn test_hgbr_convergence() {
        // MSE should decrease as we add more estimators.
        let x =
            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0];

        // Helper: mean squared error between predictions and targets.
        let mse = |preds: &Array1<f64>, y: &Array1<f64>| -> f64 {
            preds
                .iter()
                .zip(y.iter())
                .map(|(p, t)| (p - t).powi(2))
                .sum::<f64>()
                / y.len() as f64
        };

        let model_few = HistGradientBoostingRegressor::<f64>::new()
            .with_n_estimators(5)
            .with_min_samples_leaf(1)
            .with_max_leaf_nodes(None)
            .with_max_depth(Some(3))
            .with_random_state(42);
        let fitted_few = model_few.fit(&x, &y).unwrap();
        let preds_few = fitted_few.predict(&x).unwrap();
        let mse_few = mse(&preds_few, &y);

        let model_many = HistGradientBoostingRegressor::<f64>::new()
            .with_n_estimators(50)
            .with_min_samples_leaf(1)
            .with_max_leaf_nodes(None)
            .with_max_depth(Some(3))
            .with_random_state(42);
        let fitted_many = model_many.fit(&x, &y).unwrap();
        let preds_many = fitted_many.predict(&x).unwrap();
        let mse_many = mse(&preds_many, &y);

        assert!(
            mse_many < mse_few,
            "Expected MSE to decrease with more estimators: {} (50) vs {} (5)",
            mse_many,
            mse_few
        );
    }
2373
    #[test]
    fn test_hgbr_max_leaf_nodes() {
        // Test that best-first growth with max_leaf_nodes works.
        let x =
            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0];

        let model = HistGradientBoostingRegressor::<f64>::new()
            .with_n_estimators(20)
            .with_min_samples_leaf(1)
            .with_max_leaf_nodes(Some(4))
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        // Smoke test: fitting and predicting succeed with a leaf cap.
        assert_eq!(preds.len(), 8);
    }
2390
    #[test]
    fn test_hgbr_l2_regularization() {
        let x =
            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0];

        // With very high regularization, predictions should be closer to the mean.
        let model_noreg = HistGradientBoostingRegressor::<f64>::new()
            .with_n_estimators(20)
            .with_min_samples_leaf(1)
            .with_max_leaf_nodes(None)
            .with_max_depth(Some(3))
            .with_l2_regularization(0.0)
            .with_random_state(42);
        let fitted_noreg = model_noreg.fit(&x, &y).unwrap();
        let preds_noreg = fitted_noreg.predict(&x).unwrap();

        let model_highreg = HistGradientBoostingRegressor::<f64>::new()
            .with_n_estimators(20)
            .with_min_samples_leaf(1)
            .with_max_leaf_nodes(None)
            .with_max_depth(Some(3))
            .with_l2_regularization(100.0)
            .with_random_state(42);
        let fitted_highreg = model_highreg.fit(&x, &y).unwrap();
        let preds_highreg = fitted_highreg.predict(&x).unwrap();

        // With high reg, variance of predictions should be smaller.
        let var = |preds: &Array1<f64>| -> f64 {
            let mean = preds.mean().unwrap();
            preds.iter().map(|p| (p - mean).powi(2)).sum::<f64>() / preds.len() as f64
        };

        assert!(
            var(&preds_highreg) < var(&preds_noreg),
            "High regularization should reduce prediction variance"
        );
    }
2429
2430    // -- Classifier tests --
2431
    #[test]
    fn test_hgbc_binary_simple() {
        // Linearly separable binary problem: expect perfect training accuracy.
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let model = HistGradientBoostingClassifier::<f64>::new()
            .with_n_estimators(50)
            .with_learning_rate(0.1)
            .with_min_samples_leaf(1)
            .with_max_leaf_nodes(None)
            .with_max_depth(Some(3))
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 8);
        for i in 0..4 {
            assert_eq!(preds[i], 0, "Expected 0 at index {}, got {}", i, preds[i]);
        }
        for i in 4..8 {
            assert_eq!(preds[i], 1, "Expected 1 at index {}, got {}", i, preds[i]);
        }
    }
2461
    #[test]
    fn test_hgbc_multiclass() {
        // Three ordered classes on one feature; require at least 6/9
        // training accuracy rather than perfection.
        let x = Array2::from_shape_vec((9, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0])
            .unwrap();
        let y = array![0, 0, 0, 1, 1, 1, 2, 2, 2];

        let model = HistGradientBoostingClassifier::<f64>::new()
            .with_n_estimators(50)
            .with_learning_rate(0.1)
            .with_min_samples_leaf(1)
            .with_max_leaf_nodes(None)
            .with_max_depth(Some(3))
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        assert_eq!(preds.len(), 9);
        let correct = preds.iter().zip(y.iter()).filter(|(p, t)| p == t).count();
        assert!(
            correct >= 6,
            "Expected at least 6/9 correct, got {}/9",
            correct
        );
    }
2486
    #[test]
    fn test_hgbc_has_classes() {
        // HasClasses accessors expose the fitted label set and its size.
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 1, 2, 0, 1, 2];

        let model = HistGradientBoostingClassifier::<f64>::new()
            .with_n_estimators(5)
            .with_min_samples_leaf(1)
            .with_max_leaf_nodes(None)
            .with_max_depth(Some(3))
            .with_random_state(0);
        let fitted = model.fit(&x, &y).unwrap();

        assert_eq!(fitted.classes(), &[0, 1, 2]);
        assert_eq!(fitted.n_classes(), 3);
    }
2503
    #[test]
    fn test_hgbc_reproducibility() {
        // Two fits with the same random_state must classify identically.
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let model = HistGradientBoostingClassifier::<f64>::new()
            .with_n_estimators(10)
            .with_min_samples_leaf(1)
            .with_max_leaf_nodes(None)
            .with_max_depth(Some(3))
            .with_random_state(42);

        let fitted1 = model.fit(&x, &y).unwrap();
        let fitted2 = model.fit(&x, &y).unwrap();

        let preds1 = fitted1.predict(&x).unwrap();
        let preds2 = fitted2.predict(&x).unwrap();
        assert_eq!(preds1, preds2);
    }
2529
    #[test]
    fn test_hgbc_feature_importances() {
        // Only feature 0 varies; it must outrank the constant features.
        let x = Array2::from_shape_vec(
            (10, 3),
            vec![
                1.0, 0.0, 0.0, 2.0, 0.0, 0.0, 3.0, 0.0, 0.0, 4.0, 0.0, 0.0, 5.0, 0.0, 0.0, 6.0,
                0.0, 0.0, 7.0, 0.0, 0.0, 8.0, 0.0, 0.0, 9.0, 0.0, 0.0, 10.0, 0.0, 0.0,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 0, 1, 1, 1, 1, 1];

        let model = HistGradientBoostingClassifier::<f64>::new()
            .with_n_estimators(20)
            .with_min_samples_leaf(1)
            .with_max_leaf_nodes(None)
            .with_max_depth(Some(3))
            .with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        let importances = fitted.feature_importances();

        assert_eq!(importances.len(), 3);
        assert!(importances[0] > importances[1]);
    }
2554
    #[test]
    fn test_hgbc_shape_mismatch_fit() {
        // y shorter than x's row count must be rejected at fit time.
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![0, 1];

        let model = HistGradientBoostingClassifier::<f64>::new().with_n_estimators(5);
        assert!(model.fit(&x, &y).is_err());
    }
2563
    #[test]
    fn test_hgbc_shape_mismatch_predict() {
        // Predicting with a different feature count than fit must error.
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let model = HistGradientBoostingClassifier::<f64>::new()
            .with_n_estimators(5)
            .with_min_samples_leaf(1)
            .with_max_leaf_nodes(None)
            .with_max_depth(Some(3))
            .with_random_state(0);
        let fitted = model.fit(&x, &y).unwrap();

        let x_bad = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        assert!(fitted.predict(&x_bad).is_err());
    }
2581
2582    #[test]
2583    fn test_hgbc_empty_data() {
2584        let x = Array2::<f64>::zeros((0, 2));
2585        let y = Array1::<usize>::zeros(0);
2586
2587        let model = HistGradientBoostingClassifier::<f64>::new().with_n_estimators(5);
2588        assert!(model.fit(&x, &y).is_err());
2589    }
2590
2591    #[test]
2592    fn test_hgbc_single_class() {
2593        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
2594        let y = array![0, 0, 0];
2595
2596        let model = HistGradientBoostingClassifier::<f64>::new().with_n_estimators(5);
2597        assert!(model.fit(&x, &y).is_err());
2598    }
2599
2600    #[test]
2601    fn test_hgbc_zero_estimators() {
2602        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
2603        let y = array![0, 0, 1, 1];
2604
2605        let model = HistGradientBoostingClassifier::<f64>::new().with_n_estimators(0);
2606        assert!(model.fit(&x, &y).is_err());
2607    }
2608
2609    #[test]
2610    fn test_hgbc_pipeline_integration() {
2611        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
2612        let y = Array1::from_vec(vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0]);
2613
2614        let model = HistGradientBoostingClassifier::<f64>::new()
2615            .with_n_estimators(10)
2616            .with_min_samples_leaf(1)
2617            .with_max_leaf_nodes(None)
2618            .with_max_depth(Some(3))
2619            .with_random_state(42);
2620        let fitted = model.fit_pipeline(&x, &y).unwrap();
2621        let preds = fitted.predict_pipeline(&x).unwrap();
2622        assert_eq!(preds.len(), 6);
2623    }
2624
2625    #[test]
2626    fn test_hgbc_f32_support() {
2627        let x = Array2::from_shape_vec((6, 1), vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
2628        let y = array![0, 0, 0, 1, 1, 1];
2629
2630        let model = HistGradientBoostingClassifier::<f32>::new()
2631            .with_n_estimators(10)
2632            .with_min_samples_leaf(1)
2633            .with_max_leaf_nodes(None)
2634            .with_max_depth(Some(3))
2635            .with_random_state(42);
2636        let fitted = model.fit(&x, &y).unwrap();
2637        let preds = fitted.predict(&x).unwrap();
2638        assert_eq!(preds.len(), 6);
2639    }
2640
2641    #[test]
2642    fn test_hgbc_default_trait() {
2643        let model = HistGradientBoostingClassifier::<f64>::default();
2644        assert_eq!(model.n_estimators, 100);
2645        assert!((model.learning_rate - 0.1).abs() < 1e-10);
2646        assert_eq!(model.max_bins, 255);
2647    }
2648
2649    #[test]
2650    fn test_hgbc_non_contiguous_labels() {
2651        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
2652        let y = array![10, 10, 10, 20, 20, 20];
2653
2654        let model = HistGradientBoostingClassifier::<f64>::new()
2655            .with_n_estimators(20)
2656            .with_min_samples_leaf(1)
2657            .with_max_leaf_nodes(None)
2658            .with_max_depth(Some(3))
2659            .with_random_state(42);
2660        let fitted = model.fit(&x, &y).unwrap();
2661        let preds = fitted.predict(&x).unwrap();
2662
2663        assert_eq!(preds.len(), 6);
2664        for &p in preds.iter() {
2665            assert!(p == 10 || p == 20);
2666        }
2667    }
2668
2669    #[test]
2670    fn test_hgbc_nan_handling() {
2671        // Classifier should handle NaN features.
2672        let x = Array2::from_shape_vec(
2673            (8, 2),
2674            vec![
2675                1.0,
2676                f64::NAN,
2677                2.0,
2678                3.0,
2679                f64::NAN,
2680                3.0,
2681                4.0,
2682                4.0,
2683                5.0,
2684                6.0,
2685                6.0,
2686                f64::NAN,
2687                7.0,
2688                8.0,
2689                f64::NAN,
2690                9.0,
2691            ],
2692        )
2693        .unwrap();
2694        let y = array![0, 0, 0, 0, 1, 1, 1, 1];
2695
2696        let model = HistGradientBoostingClassifier::<f64>::new()
2697            .with_n_estimators(20)
2698            .with_min_samples_leaf(1)
2699            .with_max_leaf_nodes(None)
2700            .with_max_depth(Some(3))
2701            .with_random_state(42);
2702        let fitted = model.fit(&x, &y).unwrap();
2703        let preds = fitted.predict(&x).unwrap();
2704        assert_eq!(preds.len(), 8);
2705        // All predictions should be valid class labels.
2706        for &p in preds.iter() {
2707            assert!(p == 0 || p == 1);
2708        }
2709    }
2710
2711    // -- Comparison with standard GBM --
2712
2713    #[test]
2714    fn test_hist_vs_standard_gbm_similar_accuracy() {
2715        // Both models should achieve comparable accuracy on a simple task.
2716        use crate::GradientBoostingRegressor;
2717
2718        let x = Array2::from_shape_vec(
2719            (10, 1),
2720            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0],
2721        )
2722        .unwrap();
2723        let y = array![1.0, 1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0, 5.0];
2724
2725        let mse = |preds: &Array1<f64>, y: &Array1<f64>| -> f64 {
2726            preds
2727                .iter()
2728                .zip(y.iter())
2729                .map(|(p, t)| (p - t).powi(2))
2730                .sum::<f64>()
2731                / y.len() as f64
2732        };
2733
2734        // Standard GBM.
2735        let std_model = GradientBoostingRegressor::<f64>::new()
2736            .with_n_estimators(50)
2737            .with_learning_rate(0.1)
2738            .with_max_depth(Some(3))
2739            .with_random_state(42);
2740        let std_fitted = std_model.fit(&x, &y).unwrap();
2741        let std_preds = std_fitted.predict(&x).unwrap();
2742        let std_mse = mse(&std_preds, &y);
2743
2744        // Histogram GBM.
2745        let hist_model = HistGradientBoostingRegressor::<f64>::new()
2746            .with_n_estimators(50)
2747            .with_learning_rate(0.1)
2748            .with_min_samples_leaf(1)
2749            .with_max_leaf_nodes(None)
2750            .with_max_depth(Some(3))
2751            .with_random_state(42);
2752        let hist_fitted = hist_model.fit(&x, &y).unwrap();
2753        let hist_preds = hist_fitted.predict(&x).unwrap();
2754        let hist_mse = mse(&hist_preds, &y);
2755
2756        // Both should have low MSE on this simple task.
2757        assert!(std_mse < 1.0, "Standard GBM MSE too high: {}", std_mse);
2758        assert!(hist_mse < 1.0, "Hist GBM MSE too high: {}", hist_mse);
2759    }
2760}