Skip to main content

irithyll_core/
quantize.rs

//! f64 → f32 quantization utilities for packed export.
//!
//! When converting a trained SGBT (f64 precision) to the packed format (f32),
//! thresholds and leaf values are quantized. This module provides the
//! conversion helpers plus validation to check that the resulting precision
//! loss stays within an acceptable tolerance.

/// Default upper bound on the absolute round-trip error `|x - (x as f32) as f64|`
/// accepted when narrowing f64 model parameters to f32.
///
/// This is an absolute (not relative) bound. Rounding to nearest can be off by
/// up to half an f32 ULP, which exceeds `1e-5` once `|x| >= 256`, so values of
/// that magnitude only pass if they happen to be exactly representable in f32.
pub const DEFAULT_TOLERANCE: f64 = 1.0e-5;

/// Narrow an f64 split threshold to the f32 precision used by the packed format.
///
/// Follows Rust's float-narrowing cast semantics: the result is the nearest
/// representable f32, finite magnitudes beyond `f32::MAX` become `±inf`, and
/// NaN stays NaN.
#[inline]
pub fn quantize_threshold(value: f64) -> f32 {
    // There is no `From<f64> for f32` (the conversion is lossy), so an `as`
    // cast is the narrowing primitive here.
    value as f32
}

/// Pre-scale a leaf value by the learning rate and narrow the product to f32.
///
/// The product `learning_rate * leaf_value` is formed in f64 first, so the
/// only rounding to f32 precision happens in the final cast.
#[inline]
pub fn quantize_leaf(leaf_value: f64, learning_rate: f64) -> f32 {
    let scaled = learning_rate * leaf_value;
    scaled as f32
}

24/// Check whether quantizing `value` to f32 stays within tolerance.
25///
26/// Returns `true` if `|value - (value as f32) as f64| <= tolerance`.
27#[inline]
28pub fn within_tolerance(value: f64, tolerance: f64) -> bool {
29    let quantized = value as f32;
30    let roundtrip = quantized as f64;
31    let diff = crate::math::abs(value - roundtrip);
32    diff <= tolerance
33}
34
35/// Compute the maximum absolute quantization error across a slice of f64 values.
36pub fn max_quantization_error(values: &[f64]) -> f64 {
37    values
38        .iter()
39        .map(|&v| {
40            let q = v as f32;
41            crate::math::abs(v - q as f64)
42        })
43        .fold(0.0f64, |a, b| if a > b { a } else { b })
44}
45
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn small_values_quantize_exactly() {
        // Zero, ±1 and 0.5 are exactly representable in f32, so they must pass
        // even with a zero tolerance.
        for &v in &[0.0, 1.0, -1.0, 0.5] {
            assert!(within_tolerance(v, DEFAULT_TOLERANCE), "{} should round-trip", v);
            assert!(within_tolerance(v, 0.0), "{} should round-trip exactly", v);
        }
    }

    #[test]
    fn typical_thresholds_within_tolerance() {
        // Typical tree thresholds are small-ish floats
        let thresholds = [0.001, 0.1, 1.5, 10.0, 100.0, -0.5, -50.0];
        for &t in &thresholds {
            assert!(
                within_tolerance(t, DEFAULT_TOLERANCE),
                "threshold {} should be within tolerance",
                t
            );
        }
    }

    #[test]
    fn nan_and_overflowing_values_fail_tolerance() {
        // NaN never compares `<=` to anything.
        assert!(!within_tolerance(f64::NAN, DEFAULT_TOLERANCE));
        // Beyond f32::MAX the narrowed value is infinite, so the error is unbounded.
        assert!(!within_tolerance(1e300, DEFAULT_TOLERANCE));
    }

    #[test]
    fn quantize_threshold_narrows_and_overflows_to_infinity() {
        assert_eq!(quantize_threshold(1.5), 1.5f32);
        assert_eq!(quantize_threshold(1e300), f32::INFINITY);
        assert_eq!(quantize_threshold(-1e300), f32::NEG_INFINITY);
    }

    #[test]
    fn quantize_leaf_bakes_in_lr() {
        let leaf = 2.0;
        let lr = 0.1;
        let q = quantize_leaf(leaf, lr);
        assert!((q - 0.2f32).abs() < 1e-7);
        // Exactly representable operands give an exact result.
        assert_eq!(quantize_leaf(-3.0, 0.25), -0.75f32);
    }

    #[test]
    fn max_error_of_empty_slice() {
        assert_eq!(max_quantization_error(&[]), 0.0);
    }

    #[test]
    fn max_error_tracks_worst_case() {
        // 0.0 and 1.0 round-trip exactly, so the maximum must be 0.1's error.
        let expected = f64::from(0.1f64 as f32) - 0.1;
        let values = [0.0, 1.0, 0.1];
        let err = max_quantization_error(&values);
        assert!(err > 0.0);
        assert!(err < 1e-7); // 0.1 still very close in f32
        assert_eq!(err, expected);
        // The result must not depend on where the worst element sits.
        assert_eq!(max_quantization_error(&[0.1, 0.0, 1.0]), expected);
    }
}