// irithyll_core/histogram/bins.rs

1//! Feature histogram storage with gradient/hessian sums per bin.
2//!
3//! Each `FeatureHistogram` accumulates gradient, hessian, and count statistics
4//! into discrete bins defined by `BinEdges`. The SoA (Structure of Arrays) layout
5//! keeps each statistic contiguous in memory for cache-friendly sequential scans
6//! during split evaluation.
7
8use alloc::vec;
9use alloc::vec::Vec;
10
11use crate::histogram::BinEdges;
12
/// Histogram for a single feature: accumulates grad/hess/count per bin.
/// SoA layout for cache efficiency during split-point evaluation.
#[derive(Debug, Clone)]
pub struct FeatureHistogram {
    /// Gradient sums per bin.
    ///
    /// When `decay_scale != 1.0`, these are stored in un-decayed coordinates.
    /// Multiply by `decay_scale` to recover effective (decayed) values.
    pub grad_sums: Vec<f64>,
    /// Hessian sums per bin (same scaling convention as `grad_sums`).
    pub hess_sums: Vec<f64>,
    /// Sample counts per bin. Never decayed: these feed the Hoeffding
    /// bound, which needs raw sample counts.
    pub counts: Vec<u64>,
    /// Bin edges for this feature.
    pub edges: BinEdges,
    /// Accumulated decay factor for lazy forward decay.
    ///
    /// Instead of decaying all bins on every sample (O(n_bins)), we track
    /// a running scale factor and store new samples in un-decayed coordinates.
    /// The decay is materialized (applied to all bins) only when the histogram
    /// is read -- typically at split evaluation time, every `grace_period` samples.
    ///
    /// Invariant: `effective_value[i] = grad_sums[i] * decay_scale`.
    /// Value is 1.0 when no decay is pending (default, or after [`materialize_decay`]).
    decay_scale: f64,
}
39
40impl FeatureHistogram {
41    /// Create a new zeroed histogram with the given bin edges.
42    pub fn new(edges: BinEdges) -> Self {
43        let n = edges.n_bins();
44        Self {
45            grad_sums: vec![0.0; n],
46            hess_sums: vec![0.0; n],
47            counts: vec![0; n],
48            edges,
49            decay_scale: 1.0,
50        }
51    }
52
53    /// Accumulate a single sample into the appropriate bin.
54    ///
55    /// Finds the bin for `value` using the stored edges, then adds the
56    /// gradient, hessian, and increments the count for that bin.
57    #[inline]
58    pub fn accumulate(&mut self, value: f64, gradient: f64, hessian: f64) {
59        let bin = self.edges.find_bin(value);
60        self.grad_sums[bin] += gradient;
61        self.hess_sums[bin] += hessian;
62        self.counts[bin] += 1;
63    }
64
65    /// Sum of all gradient accumulators across bins.
66    ///
67    /// Accounts for any pending lazy decay by multiplying by `decay_scale`.
68    pub fn total_gradient(&self) -> f64 {
69        let raw = {
70            #[cfg(feature = "simd")]
71            {
72                crate::histogram::simd::sum_f64(&self.grad_sums)
73            }
74            #[cfg(not(feature = "simd"))]
75            {
76                self.grad_sums.iter().sum::<f64>()
77            }
78        };
79        raw * self.decay_scale
80    }
81
82    /// Sum of all hessian accumulators across bins.
83    ///
84    /// Accounts for any pending lazy decay by multiplying by `decay_scale`.
85    pub fn total_hessian(&self) -> f64 {
86        let raw = {
87            #[cfg(feature = "simd")]
88            {
89                crate::histogram::simd::sum_f64(&self.hess_sums)
90            }
91            #[cfg(not(feature = "simd"))]
92            {
93                self.hess_sums.iter().sum::<f64>()
94            }
95        };
96        raw * self.decay_scale
97    }
98
99    /// Total sample count across all bins.
100    pub fn total_count(&self) -> u64 {
101        self.counts.iter().sum()
102    }
103
104    /// Number of bins in this histogram.
105    #[inline]
106    pub fn n_bins(&self) -> usize {
107        self.edges.n_bins()
108    }
109
110    /// Accumulate with lazy forward decay: O(1) per sample.
111    ///
112    /// Implements the forward decay scheme (Cormode et al. 2009) using a
113    /// deferred scaling technique. Instead of decaying all bins on every
114    /// sample, we track a running `decay_scale` and store new samples in
115    /// un-decayed coordinates. The actual bin values are materialized only
116    /// when read (see `materialize_decay()`).
117    ///
118    /// Mathematically equivalent to eager decay: a gradient `g` added at
119    /// epoch `t` contributes `g * alpha^(T - t)` at read time `T`.
120    ///
121    /// The `counts` array is NOT decayed -- it tracks the raw sample count
122    /// for the Hoeffding bound computation.
123    #[inline]
124    pub fn accumulate_with_decay(&mut self, value: f64, gradient: f64, hessian: f64, alpha: f64) {
125        self.decay_scale *= alpha;
126        let inv_scale = 1.0 / self.decay_scale;
127        let bin = self.edges.find_bin(value);
128        self.grad_sums[bin] += gradient * inv_scale;
129        self.hess_sums[bin] += hessian * inv_scale;
130        self.counts[bin] += 1;
131
132        // Renormalize when scale underflows to prevent precision loss.
133        // With alpha ≈ 0.986 (half_life = 50), this fires roughly every
134        // 16K samples per leaf -- negligible amortized cost.
135        if self.decay_scale < 1e-100 {
136            self.materialize_decay();
137        }
138    }
139
140    /// Apply the pending decay factor to all bins and reset the scale.
141    ///
142    /// After calling this, `grad_sums` and `hess_sums` contain the true
143    /// decayed values and can be passed directly to split evaluation.
144    /// O(n_bins) -- called at split evaluation time, not per sample.
145    #[inline]
146    pub fn materialize_decay(&mut self) {
147        if (self.decay_scale - 1.0).abs() > f64::EPSILON {
148            for i in 0..self.grad_sums.len() {
149                self.grad_sums[i] *= self.decay_scale;
150                self.hess_sums[i] *= self.decay_scale;
151            }
152            self.decay_scale = 1.0;
153        }
154    }
155
156    /// Reset all accumulators to zero, preserving the bin edges.
157    pub fn reset(&mut self) {
158        self.grad_sums.fill(0.0);
159        self.hess_sums.fill(0.0);
160        self.counts.fill(0);
161        self.decay_scale = 1.0;
162    }
163
164    /// Histogram subtraction trick: `self - child = sibling`.
165    ///
166    /// When building a tree, if you know the parent histogram and one child's
167    /// histogram, you can derive the other child by subtraction rather than
168    /// re-scanning the data. This halves the histogram-building cost at each
169    /// split.
170    ///
171    /// Scale-aware: if either operand has pending lazy decay, the effective
172    /// (decayed) values are used. The result has `decay_scale = 1.0`.
173    ///
174    /// The returned histogram uses edges cloned from `self`.
175    pub fn subtract(&self, child: &FeatureHistogram) -> FeatureHistogram {
176        debug_assert_eq!(
177            self.n_bins(),
178            child.n_bins(),
179            "cannot subtract histograms with different bin counts"
180        );
181        let n = self.n_bins();
182        let s_self = self.decay_scale;
183        let s_child = child.decay_scale;
184
185        // Fast path: both scales are 1.0, use existing SIMD logic directly.
186        let scales_trivial =
187            (s_self - 1.0).abs() <= f64::EPSILON && (s_child - 1.0).abs() <= f64::EPSILON;
188
189        #[cfg(feature = "simd")]
190        {
191            if scales_trivial {
192                let mut grad_sums = vec![0.0; n];
193                let mut hess_sums = vec![0.0; n];
194                let mut counts = vec![0u64; n];
195                crate::histogram::simd::subtract_f64(
196                    &self.grad_sums,
197                    &child.grad_sums,
198                    &mut grad_sums,
199                );
200                crate::histogram::simd::subtract_f64(
201                    &self.hess_sums,
202                    &child.hess_sums,
203                    &mut hess_sums,
204                );
205                crate::histogram::simd::subtract_u64(&self.counts, &child.counts, &mut counts);
206                return FeatureHistogram {
207                    grad_sums,
208                    hess_sums,
209                    counts,
210                    edges: self.edges.clone(),
211                    decay_scale: 1.0,
212                };
213            }
214        }
215
216        // Scale-aware scalar path (also used as non-SIMD fallback).
217        let _ = scales_trivial; // suppress unused warning in non-SIMD builds
218        let mut grad_sums = Vec::with_capacity(n);
219        let mut hess_sums = Vec::with_capacity(n);
220        let mut counts = Vec::with_capacity(n);
221
222        for i in 0..n {
223            grad_sums.push(self.grad_sums[i] * s_self - child.grad_sums[i] * s_child);
224            hess_sums.push(self.hess_sums[i] * s_self - child.hess_sums[i] * s_child);
225            counts.push(self.counts[i].saturating_sub(child.counts[i]));
226        }
227
228        FeatureHistogram {
229            grad_sums,
230            hess_sums,
231            counts,
232            edges: self.edges.clone(),
233            decay_scale: 1.0,
234        }
235    }
236}
237
/// Collection of histograms for all features at a single leaf node.
///
/// During tree construction, each active leaf maintains one `LeafHistograms`
/// that accumulates statistics from all samples routed to that leaf.
/// Every histogram sees the same samples, so their total counts agree.
#[derive(Debug, Clone)]
pub struct LeafHistograms {
    /// Per-feature histograms (one per feature column).
    pub histograms: Vec<FeatureHistogram>,
}
247
248impl LeafHistograms {
249    /// Create histograms for all features, one per entry in `edges_per_feature`.
250    pub fn new(edges_per_feature: &[BinEdges]) -> Self {
251        let histograms = edges_per_feature
252            .iter()
253            .map(|edges| FeatureHistogram::new(edges.clone()))
254            .collect();
255        Self { histograms }
256    }
257
258    /// Accumulate a sample across all feature histograms.
259    ///
260    /// `features` must have the same length as the number of histograms.
261    /// Each feature value is accumulated into its corresponding histogram
262    /// with the shared gradient and hessian.
263    pub fn accumulate(&mut self, features: &[f64], gradient: f64, hessian: f64) {
264        debug_assert_eq!(
265            features.len(),
266            self.histograms.len(),
267            "feature count mismatch: got {} features but have {} histograms",
268            features.len(),
269            self.histograms.len(),
270        );
271        for (hist, &value) in self.histograms.iter_mut().zip(features.iter()) {
272            hist.accumulate(value, gradient, hessian);
273        }
274    }
275
276    /// Accumulate a sample with forward decay across all feature histograms.
277    ///
278    /// Each histogram is decayed by `alpha` before accumulating the new sample.
279    pub fn accumulate_with_decay(
280        &mut self,
281        features: &[f64],
282        gradient: f64,
283        hessian: f64,
284        alpha: f64,
285    ) {
286        debug_assert_eq!(
287            features.len(),
288            self.histograms.len(),
289            "feature count mismatch: got {} features but have {} histograms",
290            features.len(),
291            self.histograms.len(),
292        );
293        for (hist, &value) in self.histograms.iter_mut().zip(features.iter()) {
294            hist.accumulate_with_decay(value, gradient, hessian, alpha);
295        }
296    }
297
298    /// Total sample count, taken from the first histogram.
299    ///
300    /// All histograms receive the same samples, so their total counts must agree.
301    /// Returns 0 if there are no features.
302    pub fn total_count(&self) -> u64 {
303        self.histograms.first().map_or(0, |h| h.total_count())
304    }
305
306    /// Number of features (histograms).
307    #[inline]
308    pub fn n_features(&self) -> usize {
309        self.histograms.len()
310    }
311
312    /// Materialize pending lazy decay across all feature histograms.
313    ///
314    /// Call before reading raw bin values (e.g., split evaluation).
315    /// O(n_features * n_bins) -- called at split evaluation time, not per sample.
316    pub fn materialize_decay(&mut self) {
317        for hist in &mut self.histograms {
318            hist.materialize_decay();
319        }
320    }
321
322    /// Reset all histograms to zero, preserving bin edges.
323    pub fn reset(&mut self) {
324        for hist in &mut self.histograms {
325            hist.reset();
326        }
327    }
328}
329
#[cfg(test)]
mod tests {
    //! Unit tests covering plain accumulation, the subtraction trick,
    //! and the lazy forward-decay machinery (including equivalence with
    //! an eager reference implementation).

    use super::*;
    use crate::histogram::BinEdges;

    /// Helper: edges at [2.0, 4.0, 6.0] giving 4 bins:
    /// (-inf, 2.0], (2.0, 4.0], (4.0, 6.0], (6.0, +inf)
    fn four_bin_edges() -> BinEdges {
        BinEdges {
            edges: vec![2.0, 4.0, 6.0],
        }
    }

    #[test]
    fn feature_histogram_new_is_zeroed() {
        let h = FeatureHistogram::new(four_bin_edges());
        assert_eq!(h.n_bins(), 4);
        // Exact float comparison is fine here: sums over all-zero bins are exactly 0.0.
        assert_eq!(h.total_gradient(), 0.0);
        assert_eq!(h.total_hessian(), 0.0);
        assert_eq!(h.total_count(), 0);
    }

    #[test]
    fn feature_histogram_accumulate_single() {
        let mut h = FeatureHistogram::new(four_bin_edges());
        // value 3.0 -> bin 1 (between edges 2.0 and 4.0)
        h.accumulate(3.0, 1.5, 0.5);
        assert_eq!(h.counts[1], 1);
        assert_eq!(h.grad_sums[1], 1.5);
        assert_eq!(h.hess_sums[1], 0.5);
        assert_eq!(h.total_count(), 1);
    }

    #[test]
    fn feature_histogram_accumulate_multiple() {
        let mut h = FeatureHistogram::new(four_bin_edges());
        // bin 0: values < 2.0
        h.accumulate(1.0, 0.1, 0.01);
        h.accumulate(0.5, 0.2, 0.02);
        // bin 1: 2.0 < values <= 4.0 (value == 2.0 lands in bin 1 because Ok(0) -> 0+1=1)
        h.accumulate(2.0, 0.3, 0.03);
        // bin 2: 4.0 < values <= 6.0
        h.accumulate(5.0, 0.4, 0.04);
        // bin 3: values > 6.0
        h.accumulate(7.0, 0.5, 0.05);
        h.accumulate(8.0, 0.6, 0.06);

        assert_eq!(h.counts, vec![2, 1, 1, 2]);
        assert_eq!(h.total_count(), 6);

        let total_grad = 0.1 + 0.2 + 0.3 + 0.4 + 0.5 + 0.6;
        assert!((h.total_gradient() - total_grad).abs() < 1e-12);

        let total_hess = 0.01 + 0.02 + 0.03 + 0.04 + 0.05 + 0.06;
        assert!((h.total_hessian() - total_hess).abs() < 1e-12);
    }

    #[test]
    fn feature_histogram_subtraction_trick() {
        let edges = four_bin_edges();
        let mut parent = FeatureHistogram::new(edges.clone());
        let mut child = FeatureHistogram::new(edges);

        // Populate parent with 6 samples
        parent.accumulate(1.0, 0.1, 0.01);
        parent.accumulate(3.0, 0.2, 0.02);
        parent.accumulate(5.0, 0.3, 0.03);
        parent.accumulate(7.0, 0.4, 0.04);
        parent.accumulate(1.5, 0.5, 0.05);
        parent.accumulate(5.5, 0.6, 0.06);

        // Child gets 3 of those samples
        child.accumulate(1.0, 0.1, 0.01);
        child.accumulate(5.0, 0.3, 0.03);
        child.accumulate(7.0, 0.4, 0.04);

        let sibling = parent.subtract(&child);

        // Sibling should have the remaining 3 samples
        assert_eq!(sibling.total_count(), 3);
        // bin 0: parent had 2 (1.0, 1.5), child had 1 (1.0) -> sibling 1
        assert_eq!(sibling.counts[0], 1);
        assert!((sibling.grad_sums[0] - 0.5).abs() < 1e-12);
        assert!((sibling.hess_sums[0] - 0.05).abs() < 1e-12);
        // bin 1: parent had 1 (3.0), child had 0 -> sibling 1
        assert_eq!(sibling.counts[1], 1);
        assert!((sibling.grad_sums[1] - 0.2).abs() < 1e-12);
        // bin 2: parent had 2 (5.0, 5.5), child had 1 (5.0) -> sibling 1
        assert_eq!(sibling.counts[2], 1);
        assert!((sibling.grad_sums[2] - 0.6).abs() < 1e-12);
        // bin 3: parent had 1 (7.0), child had 1 (7.0) -> sibling 0
        assert_eq!(sibling.counts[3], 0);
        assert!((sibling.grad_sums[3]).abs() < 1e-12);
    }

    #[test]
    fn feature_histogram_reset() {
        let mut h = FeatureHistogram::new(four_bin_edges());
        h.accumulate(1.0, 1.0, 1.0);
        h.accumulate(3.0, 2.0, 2.0);
        assert_eq!(h.total_count(), 2);

        h.reset();
        assert_eq!(h.total_count(), 0);
        assert_eq!(h.total_gradient(), 0.0);
        assert_eq!(h.total_hessian(), 0.0);
        assert_eq!(h.n_bins(), 4); // edges preserved
    }

    #[test]
    fn leaf_histograms_multi_feature() {
        let edges_f0 = BinEdges { edges: vec![5.0] }; // 2 bins
        let edges_f1 = BinEdges {
            edges: vec![2.0, 4.0, 6.0],
        }; // 4 bins

        let mut leaf = LeafHistograms::new(&[edges_f0, edges_f1]);
        assert_eq!(leaf.n_features(), 2);

        // Sample 1: f0=3.0 (bin 0), f1=5.0 (bin 2)
        leaf.accumulate(&[3.0, 5.0], 1.0, 0.1);
        // Sample 2: f0=7.0 (bin 1), f1=1.0 (bin 0)
        leaf.accumulate(&[7.0, 1.0], 2.0, 0.2);
        // Sample 3: f0=3.0 (bin 0), f1=3.0 (bin 1)
        leaf.accumulate(&[3.0, 3.0], 3.0, 0.3);

        assert_eq!(leaf.total_count(), 3);

        // Feature 0 checks
        let h0 = &leaf.histograms[0];
        assert_eq!(h0.counts, vec![2, 1]); // two in bin 0, one in bin 1
        assert!((h0.total_gradient() - 6.0).abs() < 1e-12);
        assert!((h0.total_hessian() - 0.6).abs() < 1e-12);

        // Feature 1 checks
        let h1 = &leaf.histograms[1];
        assert_eq!(h1.counts, vec![1, 1, 1, 0]);
        assert!((h1.total_gradient() - 6.0).abs() < 1e-12);
    }

    #[test]
    fn leaf_histograms_reset() {
        let edges = BinEdges { edges: vec![3.0] };
        let mut leaf = LeafHistograms::new(&[edges.clone(), edges]);
        leaf.accumulate(&[1.0, 5.0], 1.0, 1.0);
        assert_eq!(leaf.total_count(), 1);

        leaf.reset();
        assert_eq!(leaf.total_count(), 0);
        assert_eq!(leaf.n_features(), 2);
    }

    #[test]
    fn leaf_histograms_empty() {
        // Zero-feature leaf: total_count falls back to 0 (no first histogram).
        let leaf = LeafHistograms::new(&[]);
        assert_eq!(leaf.n_features(), 0);
        assert_eq!(leaf.total_count(), 0);
    }

    #[test]
    fn single_edge_histogram() {
        // Single edge -> 2 bins: (-inf, 5.0] and (5.0, +inf)
        let edges = BinEdges { edges: vec![5.0] };
        let mut h = FeatureHistogram::new(edges);
        assert_eq!(h.n_bins(), 2);

        h.accumulate(3.0, 1.0, 0.1);
        h.accumulate(5.0, 2.0, 0.2); // exact edge -> bin 1
        h.accumulate(7.0, 3.0, 0.3);

        assert_eq!(h.counts, vec![1, 2]);
        assert!((h.grad_sums[0] - 1.0).abs() < 1e-12);
        assert!((h.grad_sums[1] - 5.0).abs() < 1e-12);
    }

    #[test]
    fn no_edges_single_bin() {
        // No edges -> 1 bin, everything goes to bin 0
        let edges = BinEdges { edges: vec![] };
        let mut h = FeatureHistogram::new(edges);
        assert_eq!(h.n_bins(), 1);

        h.accumulate(42.0, 1.0, 0.5);
        h.accumulate(-100.0, 2.0, 0.3);

        assert_eq!(h.counts, vec![2]);
        assert!((h.total_gradient() - 3.0).abs() < 1e-12);
        assert!((h.total_hessian() - 0.8).abs() < 1e-12);
    }

    #[test]
    fn accumulate_with_decay_recent_dominates() {
        // 3 bins: edges at [3.0, 7.0] => bins [<3, 3-7, >7]
        let edges = BinEdges {
            edges: vec![3.0, 7.0],
        };
        let mut h = FeatureHistogram::new(edges);
        let alpha = 0.9;

        // 100 samples in bin 0 (value=1.0)
        for _ in 0..100 {
            h.accumulate_with_decay(1.0, 1.0, 1.0, alpha);
        }

        // 50 samples in bin 2 (value=8.0) -- enough for recent to dominate
        for _ in 0..50 {
            h.accumulate_with_decay(8.0, 1.0, 1.0, alpha);
        }

        // Materialize to get true decayed values.
        h.materialize_decay();

        // With alpha=0.9, old bin 0 decays rapidly while bin 2 accumulates.
        assert!(
            h.grad_sums[2] > h.grad_sums[0],
            "recent bin should dominate: bin2={} > bin0={}",
            h.grad_sums[2],
            h.grad_sums[0],
        );
    }

    #[test]
    fn lazy_decay_matches_eager() {
        // Compare lazy decay (new) against a manual eager implementation.
        // They must produce identical results (within f64 precision).
        let edges = BinEdges {
            edges: vec![3.0, 6.0],
        }; // 3 bins
        let alpha = 0.95;

        // Lazy histogram (our implementation).
        let mut lazy = FeatureHistogram::new(edges.clone());

        // Manual eager: track expected values by applying alpha to all bins
        // each step, then accumulating.
        let n = 3;
        let mut eager_grad = vec![0.0; n];
        let mut eager_hess = vec![0.0; n];

        let samples: Vec<(f64, f64, f64)> = vec![
            (1.0, 0.5, 1.0),  // bin 0
            (4.0, -0.3, 0.8), // bin 1
            (1.0, 0.7, 1.2),  // bin 0
            (8.0, -1.0, 0.5), // bin 2
            (5.0, 0.2, 0.9),  // bin 1
            (1.0, 0.1, 1.1),  // bin 0
            (8.0, 0.4, 0.6),  // bin 2
            (4.0, -0.5, 1.0), // bin 1
        ];

        // Reference binning mirrors BinEdges semantics: upper-inclusive edges.
        let edge_vals = [3.0, 6.0];
        for &(value, gradient, hessian) in &samples {
            // Eager: decay all bins, then accumulate.
            for i in 0..n {
                eager_grad[i] *= alpha;
                eager_hess[i] *= alpha;
            }
            let bin = if value <= edge_vals[0] {
                0
            } else if value <= edge_vals[1] {
                1
            } else {
                2
            };
            eager_grad[bin] += gradient;
            eager_hess[bin] += hessian;

            // Lazy: our O(1) implementation.
            lazy.accumulate_with_decay(value, gradient, hessian, alpha);
        }

        // Materialize lazy to get comparable values.
        lazy.materialize_decay();

        for i in 0..n {
            assert!(
                (lazy.grad_sums[i] - eager_grad[i]).abs() < 1e-10,
                "grad_sums[{}]: lazy={}, eager={}",
                i,
                lazy.grad_sums[i],
                eager_grad[i],
            );
            assert!(
                (lazy.hess_sums[i] - eager_hess[i]).abs() < 1e-10,
                "hess_sums[{}]: lazy={}, eager={}",
                i,
                lazy.hess_sums[i],
                eager_hess[i],
            );
        }
    }

    #[test]
    fn lazy_decay_total_gradient_without_materialize() {
        // total_gradient() should return correct decayed value even
        // WITHOUT explicit materialize_decay() call.
        let edges = BinEdges { edges: vec![5.0] }; // 2 bins
        let alpha = 0.9;
        let mut h = FeatureHistogram::new(edges);

        h.accumulate_with_decay(1.0, 1.0, 1.0, alpha); // bin 0
        h.accumulate_with_decay(7.0, 1.0, 1.0, alpha); // bin 1

        // Expected: first sample decayed once more = 1.0 * 0.9 = 0.9
        // Second sample = 1.0, total = 1.9
        let total = h.total_gradient();
        assert!(
            (total - 1.9).abs() < 1e-10,
            "total_gradient should account for decay_scale: got {}",
            total,
        );
    }

    #[test]
    fn materialize_is_idempotent() {
        let edges = BinEdges { edges: vec![5.0] };
        let mut h = FeatureHistogram::new(edges);
        let alpha = 0.95;

        for _ in 0..50 {
            h.accumulate_with_decay(1.0, 1.0, 1.0, alpha);
        }

        h.materialize_decay();
        let grad_after_first = h.grad_sums.clone();

        h.materialize_decay(); // second call should be no-op
        assert_eq!(
            h.grad_sums, grad_after_first,
            "second materialize should be a no-op"
        );
    }

    #[test]
    fn lazy_decay_renormalization() {
        // Use a very aggressive alpha to trigger renormalization quickly.
        let edges = BinEdges { edges: vec![5.0] };
        let mut h = FeatureHistogram::new(edges);
        let alpha = 0.5; // decay_scale halves each step, hits 1e-100 at ~332 steps

        for _ in 0..500 {
            h.accumulate_with_decay(1.0, 1.0, 1.0, alpha);
        }

        // Should not have NaN or Inf despite extreme decay.
        let total = h.total_gradient();
        assert!(
            total.is_finite(),
            "gradient should be finite after renormalization, got {}",
            total
        );
        assert!(total > 0.0, "gradient should be positive, got {}", total);

        // With alpha=0.5, geometric sum converges to 1/(1-0.5) = 2.0
        assert!(
            (total - 2.0).abs() < 0.1,
            "total gradient should converge to ~2.0, got {}",
            total,
        );
    }
}