Skip to main content

irithyll_core/histogram/
bins.rs

//! Feature histogram storage with gradient/hessian sums per bin.
//!
//! Each `FeatureHistogram` accumulates gradient, hessian, and count statistics
//! into discrete bins defined by `BinEdges`. The SoA (Structure of Arrays) layout
//! keeps each statistic contiguous in memory for cache-friendly sequential scans
//! during split evaluation.

use alloc::vec;
use alloc::vec::Vec;

use crate::histogram::BinEdges;
13/// Histogram for a single feature: accumulates grad/hess/count per bin.
14/// SoA layout for cache efficiency during split-point evaluation.
15#[derive(Debug, Clone)]
16pub struct FeatureHistogram {
17    /// Gradient sums per bin.
18    ///
19    /// When `decay_scale != 1.0`, these are stored in un-decayed coordinates.
20    /// Multiply by `decay_scale` to recover effective (decayed) values.
21    pub grad_sums: Vec<f64>,
22    /// Hessian sums per bin (same scaling convention as `grad_sums`).
23    pub hess_sums: Vec<f64>,
24    /// Sample counts per bin (never decayed).
25    pub counts: Vec<u64>,
26    /// Bin edges for this feature.
27    pub edges: BinEdges,
28    /// Accumulated decay factor for lazy forward decay.
29    ///
30    /// Instead of decaying all bins on every sample (O(n_bins)), we track
31    /// a running scale factor and store new samples in un-decayed coordinates.
32    /// The decay is materialized (applied to all bins) only when the histogram
33    /// is read -- typically at split evaluation time, every `grace_period` samples.
34    ///
35    /// Invariant: `effective_value[i] = grad_sums[i] * decay_scale`.
36    /// Value is 1.0 when no decay is pending (default, or after [`materialize_decay`]).
37    decay_scale: f64,
38}
39
40impl FeatureHistogram {
41    /// Create a new zeroed histogram with the given bin edges.
42    pub fn new(edges: BinEdges) -> Self {
43        let n = edges.n_bins();
44        Self {
45            grad_sums: vec![0.0; n],
46            hess_sums: vec![0.0; n],
47            counts: vec![0; n],
48            edges,
49            decay_scale: 1.0,
50        }
51    }
52
53    /// Accumulate a single sample into the appropriate bin.
54    ///
55    /// Finds the bin for `value` using the stored edges, then adds the
56    /// gradient, hessian, and increments the count for that bin.
57    #[inline]
58    pub fn accumulate(&mut self, value: f64, gradient: f64, hessian: f64) {
59        let bin = self.edges.find_bin(value);
60        self.grad_sums[bin] += gradient;
61        self.hess_sums[bin] += hessian;
62        self.counts[bin] += 1;
63    }
64
65    /// Sum of all gradient accumulators across bins.
66    ///
67    /// Accounts for any pending lazy decay by multiplying by `decay_scale`.
68    pub fn total_gradient(&self) -> f64 {
69        let raw = {
70            #[cfg(feature = "simd")]
71            {
72                crate::histogram::simd::sum_f64(&self.grad_sums)
73            }
74            #[cfg(not(feature = "simd"))]
75            {
76                self.grad_sums.iter().sum::<f64>()
77            }
78        };
79        raw * self.decay_scale
80    }
81
82    /// Sum of all hessian accumulators across bins.
83    ///
84    /// Accounts for any pending lazy decay by multiplying by `decay_scale`.
85    pub fn total_hessian(&self) -> f64 {
86        let raw = {
87            #[cfg(feature = "simd")]
88            {
89                crate::histogram::simd::sum_f64(&self.hess_sums)
90            }
91            #[cfg(not(feature = "simd"))]
92            {
93                self.hess_sums.iter().sum::<f64>()
94            }
95        };
96        raw * self.decay_scale
97    }
98
99    /// Total sample count across all bins.
100    pub fn total_count(&self) -> u64 {
101        self.counts.iter().sum()
102    }
103
104    /// Number of bins in this histogram.
105    #[inline]
106    pub fn n_bins(&self) -> usize {
107        self.edges.n_bins()
108    }
109
110    /// Accumulate with lazy forward decay: O(1) per sample.
111    ///
112    /// Implements the forward decay scheme (Cormode et al. 2009) using a
113    /// deferred scaling technique. Instead of decaying all bins on every
114    /// sample, we track a running `decay_scale` and store new samples in
115    /// un-decayed coordinates. The actual bin values are materialized only
116    /// when read (see `materialize_decay()`).
117    ///
118    /// Mathematically equivalent to eager decay: a gradient `g` added at
119    /// epoch `t` contributes `g * alpha^(T - t)` at read time `T`.
120    ///
121    /// The `counts` array is NOT decayed -- it tracks the raw sample count
122    /// for the Hoeffding bound computation.
123    #[inline]
124    pub fn accumulate_with_decay(&mut self, value: f64, gradient: f64, hessian: f64, alpha: f64) {
125        self.decay_scale *= alpha;
126        let inv_scale = 1.0 / self.decay_scale;
127        let bin = self.edges.find_bin(value);
128        self.grad_sums[bin] += gradient * inv_scale;
129        self.hess_sums[bin] += hessian * inv_scale;
130        self.counts[bin] += 1;
131
132        // Renormalize when scale underflows to prevent precision loss.
133        // With alpha ≈ 0.986 (half_life = 50), this fires roughly every
134        // 16K samples per leaf -- negligible amortized cost.
135        if self.decay_scale < 1e-100 {
136            self.materialize_decay();
137        }
138    }
139
140    /// Apply the pending decay factor to all bins and reset the scale.
141    ///
142    /// After calling this, `grad_sums` and `hess_sums` contain the true
143    /// decayed values and can be passed directly to split evaluation.
144    /// O(n_bins) -- called at split evaluation time, not per sample.
145    #[inline]
146    pub fn materialize_decay(&mut self) {
147        if (self.decay_scale - 1.0).abs() > f64::EPSILON {
148            for i in 0..self.grad_sums.len() {
149                self.grad_sums[i] *= self.decay_scale;
150                self.hess_sums[i] *= self.decay_scale;
151            }
152            self.decay_scale = 1.0;
153        }
154    }
155
156    /// Reset all accumulators to zero, preserving the bin edges.
157    pub fn reset(&mut self) {
158        self.grad_sums.fill(0.0);
159        self.hess_sums.fill(0.0);
160        self.counts.fill(0);
161        self.decay_scale = 1.0;
162    }
163
164    /// Histogram subtraction trick: `self - child = sibling`.
165    ///
166    /// When building a tree, if you know the parent histogram and one child's
167    /// histogram, you can derive the other child by subtraction rather than
168    /// re-scanning the data. This halves the histogram-building cost at each
169    /// split.
170    ///
171    /// Scale-aware: if either operand has pending lazy decay, the effective
172    /// (decayed) values are used. The result has `decay_scale = 1.0`.
173    ///
174    /// The returned histogram uses edges cloned from `self`.
175    pub fn subtract(&self, child: &FeatureHistogram) -> FeatureHistogram {
176        debug_assert_eq!(
177            self.n_bins(),
178            child.n_bins(),
179            "cannot subtract histograms with different bin counts"
180        );
181        let n = self.n_bins();
182        let s_self = self.decay_scale;
183        let s_child = child.decay_scale;
184
185        // Fast path: both scales are 1.0, use existing SIMD logic directly.
186        let scales_trivial =
187            (s_self - 1.0).abs() <= f64::EPSILON && (s_child - 1.0).abs() <= f64::EPSILON;
188
189        #[cfg(feature = "simd")]
190        {
191            if scales_trivial {
192                let mut grad_sums = vec![0.0; n];
193                let mut hess_sums = vec![0.0; n];
194                let mut counts = vec![0u64; n];
195                crate::histogram::simd::subtract_f64(
196                    &self.grad_sums,
197                    &child.grad_sums,
198                    &mut grad_sums,
199                );
200                crate::histogram::simd::subtract_f64(
201                    &self.hess_sums,
202                    &child.hess_sums,
203                    &mut hess_sums,
204                );
205                crate::histogram::simd::subtract_u64(&self.counts, &child.counts, &mut counts);
206                return FeatureHistogram {
207                    grad_sums,
208                    hess_sums,
209                    counts,
210                    edges: self.edges.clone(),
211                    decay_scale: 1.0,
212                };
213            }
214        }
215
216        // Scale-aware scalar path (also used as non-SIMD fallback).
217        let _ = scales_trivial; // suppress unused warning in non-SIMD builds
218        let mut grad_sums = Vec::with_capacity(n);
219        let mut hess_sums = Vec::with_capacity(n);
220        let mut counts = Vec::with_capacity(n);
221
222        for i in 0..n {
223            grad_sums.push(self.grad_sums[i] * s_self - child.grad_sums[i] * s_child);
224            hess_sums.push(self.hess_sums[i] * s_self - child.hess_sums[i] * s_child);
225            counts.push(self.counts[i].saturating_sub(child.counts[i]));
226        }
227
228        FeatureHistogram {
229            grad_sums,
230            hess_sums,
231            counts,
232            edges: self.edges.clone(),
233            decay_scale: 1.0,
234        }
235    }
236}
237
238/// Collection of histograms for all features at a single leaf node.
239///
240/// During tree construction, each active leaf maintains one `LeafHistograms`
241/// that accumulates statistics from all samples routed to that leaf.
242#[derive(Debug, Clone)]
243pub struct LeafHistograms {
244    pub histograms: Vec<FeatureHistogram>,
245}
246
247impl LeafHistograms {
248    /// Create histograms for all features, one per entry in `edges_per_feature`.
249    pub fn new(edges_per_feature: &[BinEdges]) -> Self {
250        let histograms = edges_per_feature
251            .iter()
252            .map(|edges| FeatureHistogram::new(edges.clone()))
253            .collect();
254        Self { histograms }
255    }
256
257    /// Accumulate a sample across all feature histograms.
258    ///
259    /// `features` must have the same length as the number of histograms.
260    /// Each feature value is accumulated into its corresponding histogram
261    /// with the shared gradient and hessian.
262    pub fn accumulate(&mut self, features: &[f64], gradient: f64, hessian: f64) {
263        debug_assert_eq!(
264            features.len(),
265            self.histograms.len(),
266            "feature count mismatch: got {} features but have {} histograms",
267            features.len(),
268            self.histograms.len(),
269        );
270        for (hist, &value) in self.histograms.iter_mut().zip(features.iter()) {
271            hist.accumulate(value, gradient, hessian);
272        }
273    }
274
275    /// Accumulate a sample with forward decay across all feature histograms.
276    ///
277    /// Each histogram is decayed by `alpha` before accumulating the new sample.
278    pub fn accumulate_with_decay(
279        &mut self,
280        features: &[f64],
281        gradient: f64,
282        hessian: f64,
283        alpha: f64,
284    ) {
285        debug_assert_eq!(
286            features.len(),
287            self.histograms.len(),
288            "feature count mismatch: got {} features but have {} histograms",
289            features.len(),
290            self.histograms.len(),
291        );
292        for (hist, &value) in self.histograms.iter_mut().zip(features.iter()) {
293            hist.accumulate_with_decay(value, gradient, hessian, alpha);
294        }
295    }
296
297    /// Total sample count, taken from the first histogram.
298    ///
299    /// All histograms receive the same samples, so their total counts must agree.
300    /// Returns 0 if there are no features.
301    pub fn total_count(&self) -> u64 {
302        self.histograms.first().map_or(0, |h| h.total_count())
303    }
304
305    /// Number of features (histograms).
306    #[inline]
307    pub fn n_features(&self) -> usize {
308        self.histograms.len()
309    }
310
311    /// Materialize pending lazy decay across all feature histograms.
312    ///
313    /// Call before reading raw bin values (e.g., split evaluation).
314    /// O(n_features * n_bins) -- called at split evaluation time, not per sample.
315    pub fn materialize_decay(&mut self) {
316        for hist in &mut self.histograms {
317            hist.materialize_decay();
318        }
319    }
320
321    /// Reset all histograms to zero, preserving bin edges.
322    pub fn reset(&mut self) {
323        for hist in &mut self.histograms {
324            hist.reset();
325        }
326    }
327}
328
#[cfg(test)]
mod tests {
    use super::*;
    use crate::histogram::BinEdges;

    /// Helper: edges at [2.0, 4.0, 6.0] giving 4 bins:
    /// (-inf, 2.0], (2.0, 4.0], (4.0, 6.0], (6.0, +inf)
    fn four_bin_edges() -> BinEdges {
        BinEdges {
            edges: vec![2.0, 4.0, 6.0],
        }
    }

    #[test]
    fn feature_histogram_new_is_zeroed() {
        let h = FeatureHistogram::new(four_bin_edges());
        assert_eq!(h.n_bins(), 4);
        assert_eq!(h.total_gradient(), 0.0);
        assert_eq!(h.total_hessian(), 0.0);
        assert_eq!(h.total_count(), 0);
    }

    #[test]
    fn feature_histogram_accumulate_single() {
        let mut h = FeatureHistogram::new(four_bin_edges());
        // value 3.0 -> bin 1 (between edges 2.0 and 4.0)
        h.accumulate(3.0, 1.5, 0.5);
        assert_eq!(h.counts[1], 1);
        assert_eq!(h.grad_sums[1], 1.5);
        assert_eq!(h.hess_sums[1], 0.5);
        assert_eq!(h.total_count(), 1);
    }

    #[test]
    fn feature_histogram_accumulate_multiple() {
        let mut h = FeatureHistogram::new(four_bin_edges());
        // bin 0: values < 2.0
        h.accumulate(1.0, 0.1, 0.01);
        h.accumulate(0.5, 0.2, 0.02);
        // bin 1: 2.0 < values <= 4.0 (value == 2.0 lands in bin 1 because Ok(0) -> 0+1=1)
        h.accumulate(2.0, 0.3, 0.03);
        // bin 2: 4.0 < values <= 6.0
        h.accumulate(5.0, 0.4, 0.04);
        // bin 3: values > 6.0
        h.accumulate(7.0, 0.5, 0.05);
        h.accumulate(8.0, 0.6, 0.06);

        assert_eq!(h.counts, vec![2, 1, 1, 2]);
        assert_eq!(h.total_count(), 6);

        let total_grad = 0.1 + 0.2 + 0.3 + 0.4 + 0.5 + 0.6;
        assert!((h.total_gradient() - total_grad).abs() < 1e-12);

        let total_hess = 0.01 + 0.02 + 0.03 + 0.04 + 0.05 + 0.06;
        assert!((h.total_hessian() - total_hess).abs() < 1e-12);
    }

    #[test]
    fn feature_histogram_subtraction_trick() {
        let edges = four_bin_edges();
        let mut parent = FeatureHistogram::new(edges.clone());
        let mut child = FeatureHistogram::new(edges);

        // Populate parent with 6 samples
        parent.accumulate(1.0, 0.1, 0.01);
        parent.accumulate(3.0, 0.2, 0.02);
        parent.accumulate(5.0, 0.3, 0.03);
        parent.accumulate(7.0, 0.4, 0.04);
        parent.accumulate(1.5, 0.5, 0.05);
        parent.accumulate(5.5, 0.6, 0.06);

        // Child gets 3 of those samples
        child.accumulate(1.0, 0.1, 0.01);
        child.accumulate(5.0, 0.3, 0.03);
        child.accumulate(7.0, 0.4, 0.04);

        let sibling = parent.subtract(&child);

        // Sibling should have the remaining 3 samples
        assert_eq!(sibling.total_count(), 3);
        // bin 0: parent had 2 (1.0, 1.5), child had 1 (1.0) -> sibling 1
        assert_eq!(sibling.counts[0], 1);
        assert!((sibling.grad_sums[0] - 0.5).abs() < 1e-12);
        assert!((sibling.hess_sums[0] - 0.05).abs() < 1e-12);
        // bin 1: parent had 1 (3.0), child had 0 -> sibling 1
        assert_eq!(sibling.counts[1], 1);
        assert!((sibling.grad_sums[1] - 0.2).abs() < 1e-12);
        // bin 2: parent had 2 (5.0, 5.5), child had 1 (5.0) -> sibling 1
        assert_eq!(sibling.counts[2], 1);
        assert!((sibling.grad_sums[2] - 0.6).abs() < 1e-12);
        // bin 3: parent had 1 (7.0), child had 1 (7.0) -> sibling 0
        assert_eq!(sibling.counts[3], 0);
        assert!((sibling.grad_sums[3]).abs() < 1e-12);
    }

    #[test]
    fn feature_histogram_reset() {
        let mut h = FeatureHistogram::new(four_bin_edges());
        h.accumulate(1.0, 1.0, 1.0);
        h.accumulate(3.0, 2.0, 2.0);
        assert_eq!(h.total_count(), 2);

        h.reset();
        assert_eq!(h.total_count(), 0);
        assert_eq!(h.total_gradient(), 0.0);
        assert_eq!(h.total_hessian(), 0.0);
        assert_eq!(h.n_bins(), 4); // edges preserved
    }

    #[test]
    fn leaf_histograms_multi_feature() {
        let edges_f0 = BinEdges { edges: vec![5.0] }; // 2 bins
        let edges_f1 = BinEdges {
            edges: vec![2.0, 4.0, 6.0],
        }; // 4 bins

        let mut leaf = LeafHistograms::new(&[edges_f0, edges_f1]);
        assert_eq!(leaf.n_features(), 2);

        // Sample 1: f0=3.0 (bin 0), f1=5.0 (bin 2)
        leaf.accumulate(&[3.0, 5.0], 1.0, 0.1);
        // Sample 2: f0=7.0 (bin 1), f1=1.0 (bin 0)
        leaf.accumulate(&[7.0, 1.0], 2.0, 0.2);
        // Sample 3: f0=3.0 (bin 0), f1=3.0 (bin 1)
        leaf.accumulate(&[3.0, 3.0], 3.0, 0.3);

        assert_eq!(leaf.total_count(), 3);

        // Feature 0 checks
        let h0 = &leaf.histograms[0];
        assert_eq!(h0.counts, vec![2, 1]); // two in bin 0, one in bin 1
        assert!((h0.total_gradient() - 6.0).abs() < 1e-12);
        assert!((h0.total_hessian() - 0.6).abs() < 1e-12);

        // Feature 1 checks
        let h1 = &leaf.histograms[1];
        assert_eq!(h1.counts, vec![1, 1, 1, 0]);
        assert!((h1.total_gradient() - 6.0).abs() < 1e-12);
    }

    #[test]
    fn leaf_histograms_reset() {
        let edges = BinEdges { edges: vec![3.0] };
        let mut leaf = LeafHistograms::new(&[edges.clone(), edges]);
        leaf.accumulate(&[1.0, 5.0], 1.0, 1.0);
        assert_eq!(leaf.total_count(), 1);

        leaf.reset();
        assert_eq!(leaf.total_count(), 0);
        assert_eq!(leaf.n_features(), 2);
    }

    #[test]
    fn leaf_histograms_empty() {
        let leaf = LeafHistograms::new(&[]);
        assert_eq!(leaf.n_features(), 0);
        assert_eq!(leaf.total_count(), 0);
    }

    #[test]
    fn single_edge_histogram() {
        // Single edge -> 2 bins: (-inf, 5.0] and (5.0, +inf)
        let edges = BinEdges { edges: vec![5.0] };
        let mut h = FeatureHistogram::new(edges);
        assert_eq!(h.n_bins(), 2);

        h.accumulate(3.0, 1.0, 0.1);
        h.accumulate(5.0, 2.0, 0.2); // exact edge -> bin 1
        h.accumulate(7.0, 3.0, 0.3);

        assert_eq!(h.counts, vec![1, 2]);
        assert!((h.grad_sums[0] - 1.0).abs() < 1e-12);
        assert!((h.grad_sums[1] - 5.0).abs() < 1e-12);
    }

    #[test]
    fn no_edges_single_bin() {
        // No edges -> 1 bin, everything goes to bin 0
        let edges = BinEdges { edges: vec![] };
        let mut h = FeatureHistogram::new(edges);
        assert_eq!(h.n_bins(), 1);

        h.accumulate(42.0, 1.0, 0.5);
        h.accumulate(-100.0, 2.0, 0.3);

        assert_eq!(h.counts, vec![2]);
        assert!((h.total_gradient() - 3.0).abs() < 1e-12);
        assert!((h.total_hessian() - 0.8).abs() < 1e-12);
    }

    #[test]
    fn accumulate_with_decay_recent_dominates() {
        // 3 bins: edges at [3.0, 7.0] => bins [<3, 3-7, >7]
        let edges = BinEdges {
            edges: vec![3.0, 7.0],
        };
        let mut h = FeatureHistogram::new(edges);
        let alpha = 0.9;

        // 100 samples in bin 0 (value=1.0)
        for _ in 0..100 {
            h.accumulate_with_decay(1.0, 1.0, 1.0, alpha);
        }

        // 50 samples in bin 2 (value=8.0) -- enough for recent to dominate
        for _ in 0..50 {
            h.accumulate_with_decay(8.0, 1.0, 1.0, alpha);
        }

        // Materialize to get true decayed values.
        h.materialize_decay();

        // With alpha=0.9, old bin 0 decays rapidly while bin 2 accumulates.
        assert!(
            h.grad_sums[2] > h.grad_sums[0],
            "recent bin should dominate: bin2={} > bin0={}",
            h.grad_sums[2],
            h.grad_sums[0],
        );
    }

    #[test]
    fn lazy_decay_matches_eager() {
        // Compare lazy decay (new) against a manual eager implementation.
        // They must produce identical results (within f64 precision).
        let edges = BinEdges {
            edges: vec![3.0, 6.0],
        }; // 3 bins
        let alpha = 0.95;

        // Lazy histogram (our implementation).
        let mut lazy = FeatureHistogram::new(edges.clone());

        // Manual eager: track expected values by applying alpha to all bins
        // each step, then accumulating.
        let n = 3;
        let mut eager_grad = vec![0.0; n];
        let mut eager_hess = vec![0.0; n];

        let samples: Vec<(f64, f64, f64)> = vec![
            (1.0, 0.5, 1.0),  // bin 0
            (4.0, -0.3, 0.8), // bin 1
            (1.0, 0.7, 1.2),  // bin 0
            (8.0, -1.0, 0.5), // bin 2
            (5.0, 0.2, 0.9),  // bin 1
            (1.0, 0.1, 1.1),  // bin 0
            (8.0, 0.4, 0.6),  // bin 2
            (4.0, -0.5, 1.0), // bin 1
        ];

        let edge_vals = [3.0, 6.0];
        for &(value, gradient, hessian) in &samples {
            // Eager: decay all bins, then accumulate.
            for i in 0..n {
                eager_grad[i] *= alpha;
                eager_hess[i] *= alpha;
            }
            let bin = if value <= edge_vals[0] {
                0
            } else if value <= edge_vals[1] {
                1
            } else {
                2
            };
            eager_grad[bin] += gradient;
            eager_hess[bin] += hessian;

            // Lazy: our O(1) implementation.
            lazy.accumulate_with_decay(value, gradient, hessian, alpha);
        }

        // Materialize lazy to get comparable values.
        lazy.materialize_decay();

        for i in 0..n {
            assert!(
                (lazy.grad_sums[i] - eager_grad[i]).abs() < 1e-10,
                "grad_sums[{}]: lazy={}, eager={}",
                i,
                lazy.grad_sums[i],
                eager_grad[i],
            );
            assert!(
                (lazy.hess_sums[i] - eager_hess[i]).abs() < 1e-10,
                "hess_sums[{}]: lazy={}, eager={}",
                i,
                lazy.hess_sums[i],
                eager_hess[i],
            );
        }
    }

    #[test]
    fn lazy_decay_total_gradient_without_materialize() {
        // total_gradient() should return correct decayed value even
        // WITHOUT explicit materialize_decay() call.
        let edges = BinEdges { edges: vec![5.0] }; // 2 bins
        let alpha = 0.9;
        let mut h = FeatureHistogram::new(edges);

        h.accumulate_with_decay(1.0, 1.0, 1.0, alpha); // bin 0
        h.accumulate_with_decay(7.0, 1.0, 1.0, alpha); // bin 1

        // Expected: first sample decayed once more = 1.0 * 0.9 = 0.9
        // Second sample = 1.0, total = 1.9
        let total = h.total_gradient();
        assert!(
            (total - 1.9).abs() < 1e-10,
            "total_gradient should account for decay_scale: got {}",
            total,
        );
    }

    #[test]
    fn materialize_is_idempotent() {
        let edges = BinEdges { edges: vec![5.0] };
        let mut h = FeatureHistogram::new(edges);
        let alpha = 0.95;

        for _ in 0..50 {
            h.accumulate_with_decay(1.0, 1.0, 1.0, alpha);
        }

        h.materialize_decay();
        let grad_after_first = h.grad_sums.clone();

        h.materialize_decay(); // second call should be no-op
        assert_eq!(
            h.grad_sums, grad_after_first,
            "second materialize should be a no-op"
        );
    }

    #[test]
    fn lazy_decay_renormalization() {
        // Use a very aggressive alpha to trigger renormalization quickly.
        let edges = BinEdges { edges: vec![5.0] };
        let mut h = FeatureHistogram::new(edges);
        let alpha = 0.5; // decay_scale halves each step, hits 1e-100 at ~332 steps

        for _ in 0..500 {
            h.accumulate_with_decay(1.0, 1.0, 1.0, alpha);
        }

        // Should not have NaN or Inf despite extreme decay.
        let total = h.total_gradient();
        assert!(
            total.is_finite(),
            "gradient should be finite after renormalization, got {}",
            total
        );
        assert!(total > 0.0, "gradient should be positive, got {}", total);

        // With alpha=0.5, geometric sum converges to 1/(1-0.5) = 2.0
        assert!(
            (total - 2.0).abs() < 0.1,
            "total gradient should converge to ~2.0, got {}",
            total,
        );
    }
}