// nanogbm 0.2.0
//
// A small, pure-Rust gradient boosting library (GBDT, binary classification, CPU only).
use serde::{Deserialize, Serialize};

use super::Objective;

/// Binary logistic objective: probability = sigmoid(raw_score), with
/// gradient = p - y and hessian = p(1 - p).
///
/// This is a stateless unit struct; all behaviour lives in its `Objective`
/// implementation in this file, which floors the hessian at `1e-6` when
/// computing gradients.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct BinaryObjective;

/// Numerically stable f64 logistic function, `1 / (1 + exp(-x))`.
///
/// Only `exp(-|x|)` is ever evaluated, so the exponential cannot overflow
/// for large `|x|`; the sign of `x` merely selects which algebraically
/// equivalent quotient is returned. NaN input yields NaN.
#[inline]
pub fn sigmoid(x: f64) -> f64 {
    // exp of a non-positive argument: always in [0, 1].
    let decay = (-x.abs()).exp();
    let denom = 1.0 + decay;
    if x >= 0.0 {
        1.0 / denom
    } else {
        decay / denom
    }
}

/// Fast f32 sigmoid built on a polynomial `exp2` approximation.
///
/// Used by the hot-path [`BinaryObjective::gradients_packed`]. The body is
/// straight-line f32 arithmetic with a single value select at the end, which
/// is intended to let the caller's loop autovectorize.
///
/// Accuracy: the unit tests in this file check an absolute error below 1e-4
/// against the f64 [`sigmoid`] on a 0.1-step grid over `x ∈ [-30, 30]`.
/// Inputs outside ±30 are clamped, so they return exactly the value at ±30,
/// which is within 1e-6 of 0 / 1. NaN input yields NaN.
#[inline]
pub fn sigmoid_fast_f32(x: f32) -> f32 {
    use std::f32::consts::{LN_2, LOG2_E};

    // Inputs are clamped to ±LIMIT; the sigmoid is saturated there anyway.
    const LIMIT: f32 = 30.0;
    // Coefficients of t^2 .. t^5 in the polynomial for 2^t (the t^1
    // coefficient is LN_2). C2..C4 equal the Taylor terms ln(2)^k / k!;
    // C5 is slightly larger than the Taylor term.
    const C2: f32 = 0.2402265;
    const C3: f32 = 0.0555041;
    const C4: f32 = 0.0096181;
    const C5: f32 = 0.0013427;

    let clamped = x.clamp(-LIMIT, LIMIT);

    // exp(-|x|) = 2^(-|x| * log2(e)); the exponent is in roughly [-43.3, 0],
    // so the power of two stays in (0, 1] and cannot overflow.
    let exponent = -clamped.abs() * LOG2_E;

    // Split into integer and fractional parts: 2^exponent = 2^whole * 2^frac
    // with whole = floor(exponent) and frac in [0, 1).
    let whole = exponent.floor();
    let frac = exponent - whole;

    // Degree-5 polynomial for 2^frac, evaluated with Horner's scheme. At
    // frac -> 1 it yields about 1.99984 instead of 2 (relative error near
    // 8e-5), which is what bounds the overall accuracy of this function.
    let horner = frac * C5;
    let horner = frac * (C4 + horner);
    let horner = frac * (C3 + horner);
    let horner = frac * (C2 + horner);
    let horner = frac * (LN_2 + horner);
    let two_pow_frac = 1.0 + horner;

    // 2^whole built directly from IEEE-754 bits: biased exponent is
    // (whole + 127), shifted into bits 23..31 with a zero mantissa. The clamp
    // keeps the biased exponent inside the normal range.
    let biased = (whole as i32).clamp(-126, 127) + 127;
    let two_pow_whole = f32::from_bits((biased as u32) << 23);

    let decay = two_pow_whole * two_pow_frac; // ≈ exp(-|x|), in (0, 1]

    // With decay = exp(-|x|):
    //   x  < 0 → sigmoid = decay / (1 + decay)
    //   x >= 0 → sigmoid = 1 - decay / (1 + decay)
    // Both candidates are cheap, so this is a plain value select on the sign
    // of the clamped input.
    let lower_half = decay / (1.0 + decay);
    if clamped >= 0.0 {
        1.0 - lower_half
    } else {
        lower_half
    }
}

impl Objective for BinaryObjective {
    /// Initial raw score: the log-odds `ln(m / (1 - m))` of the mean label `m`.
    ///
    /// The mean is clamped to `[1e-6, 1 - 1e-6]` so that all-0 or all-1 labels
    /// still give a finite score. An empty `labels` slice returns `0.0`
    /// (probability 0.5) instead of the NaN that `0 / 0` would produce.
    fn init_score(&self, labels: &[f32]) -> f64 {
        if labels.is_empty() {
            return 0.0;
        }
        let positives: f64 = labels.iter().map(|&y| f64::from(y)).sum();
        let mean = (positives / labels.len() as f64).clamp(1e-6, 1.0 - 1e-6);
        (mean / (1.0 - mean)).ln()
    }

    /// Maps raw scores to probabilities with the f64 [`sigmoid`].
    ///
    /// Writes `min(raw_scores.len(), out.len())` entries; any extra elements
    /// of either slice are left untouched.
    fn convert_output(&self, raw_scores: &[f64], out: &mut [f64]) {
        for (prob, &raw) in out.iter_mut().zip(raw_scores) {
            *prob = sigmoid(raw);
        }
    }

    /// Computes per-row gradient `p - y` and hessian `p(1 - p)` in f64 and
    /// stores them as f32. The hessian is floored at `1e-6`.
    ///
    /// # Panics
    ///
    /// Panics if `labels`, `grads` or `hesss` is shorter than `raw_scores`.
    /// The check happens once, before any output is written.
    fn gradients(&self, raw_scores: &[f64], labels: &[f32], grads: &mut [f32], hesss: &mut [f32]) {
        // Slice everything to the common length up front: one bounds check
        // instead of three per row, and no partially written output on panic.
        let n = raw_scores.len();
        let labels = &labels[..n];
        let grads = &mut grads[..n];
        let hesss = &mut hesss[..n];

        let rows = raw_scores.iter().zip(labels);
        let outputs = grads.iter_mut().zip(hesss.iter_mut());
        for ((&raw, &y), (grad, hess)) in rows.zip(outputs) {
            let p = sigmoid(raw);
            *grad = (p - f64::from(y)) as f32;
            *hess = (p * (1.0 - p)).max(1e-6) as f32;
        }
    }

    /// Pack-aware variant: writes `[grad, hess]` pairs into a single buffer so
    /// a consumer can read both values of a row from one contiguous 8-byte
    /// element instead of two separate arrays.
    ///
    /// Uses [`sigmoid_fast_f32`], so probabilities are computed in f32 with
    /// that function's approximation error (the unit tests in this file bound
    /// it by 1e-4 absolute over `[-30, 30]`). The hessian is floored at `1e-6`.
    ///
    /// # Panics
    ///
    /// Panics if `labels` or `out` is shorter than `raw_scores`. The check
    /// happens once, before any output is written.
    fn gradients_packed(&self, raw_scores: &[f64], labels: &[f32], out: &mut [[f32; 2]]) {
        // Hoist the length checks out of the loop so the per-row body is free
        // of bounds checks and panic paths.
        let n = raw_scores.len();
        let labels = &labels[..n];
        let out = &mut out[..n];

        for ((&raw, &y), pair) in raw_scores.iter().zip(labels).zip(out.iter_mut()) {
            let p = sigmoid_fast_f32(raw as f32);
            let h = p * (1.0 - p);
            // `h > 1e-6` is false for NaN, so a NaN hessian is floored too.
            *pair = [p - y, if h > 1e-6 { h } else { 1e-6 }];
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// The f32 approximation must stay within 1e-4 (absolute) of the f64
    /// sigmoid on a 0.1-step grid covering [-30, 30].
    #[test]
    fn fast_sigmoid_matches_libm_within_tolerance() {
        for x in (-300i32..=300).map(|step| step as f32 * 0.1) {
            let exact = sigmoid(x as f64);
            let fast = sigmoid_fast_f32(x) as f64;
            let abs_err = (fast - exact).abs();
            assert!(
                abs_err < 1e-4,
                "x={x} fast={fast} exact={exact} err={abs_err}"
            );
        }
    }

    /// Far outside the clamp range the result must be within 1e-6 of 1 / 0.
    #[test]
    fn fast_sigmoid_saturates_at_extremes() {
        let high = sigmoid_fast_f32(50.0);
        let low = sigmoid_fast_f32(-50.0);
        assert!((high - 1.0).abs() < 1e-6);
        assert!(low.abs() < 1e-6);
    }
}