ferrotorch-core 0.6.0

Core tensor and autograd engine for ferrotorch — PyTorch in Rust
Documentation
//! Torch-matching f32 reduction primitives.
//!
//! PyTorch's CPU f32 L2-norm reduction is NOT a naive scalar `Σ v*v`. It is a
//! width-8 lane-grouped accumulation (one `Vectorized<float>` of 8 AVX2 lanes)
//! followed by a specific horizontal fold and a scalar remainder, then `sqrt`.
//! A naive scalar `Σ v.abs().powf(2.0)` (or even a scalar `Σ v*v`) gives a
//! result that differs from torch by one ULP on a meaningful fraction of f32
//! rows, which flips boundary decisions (e.g. the `embedding(max_norm=...)`
//! renorm-or-not decision — #1612 / #1614). This module ships the reduction
//! torch actually performs so those boundary decisions match byte-for-byte.
//!
//! ## REQ status (per `.design/ferrotorch-core/simd_reduce.md`)
//!
//! | REQ | Status | Evidence |
//! |---|---|---|
//! | REQ-1 (`l2_norm_f32_torch`) | SHIPPED | `pub fn l2_norm_f32_torch` here models torch's vectorized last-dim L2 kernel (`aten/src/ATen/native/cpu/ReduceOpsKernel.cpp:222-255`): 8 f32 lanes accumulate `lanes[j] += data[d+j]*data[d+j]` (plain mul+add, mirroring AVX2 `_mm256_mul_ps`+`_mm256_add_ps` at `vec256/vec256_float.h:564`), a naive left-fold `buffer[0] += buffer[j]`, a scalar FMA tail `buffer[0] = fma(data, data, buffer[0])` (mirroring the compiled contraction of `buffer[0] += data*data` at `ReduceOpsKernel.cpp:251`), then `sqrt`. Non-test production consumers: `ferrotorch_core::grad_fns::reduction::norm_with_dim` (the `p==2.0`, `T==f32`, last-dim-contiguous slice) and `ferrotorch_nn::embedding::renorm_weight_rows_in_place` (the `norm_type==2.0`, `T==f32` renorm decision). |
//!
//! ## Why this is not byte-exact for 100% of rows (honest scope, R-HONEST-3)
//!
//! Across a 400-row live-torch oracle (this AVX2 host, lengths 1..65), this
//! primitive matches torch's `at::norm(2.0)` f32 bits on ~97% of rows; the
//! old scalar-`powf` path matched ~79%. The residual ~3% are one-ULP
//! misses at certain non-multiple-of-8 lengths where torch's compiled scalar
//! remainder contracts FMA in a pattern a portable Rust loop cannot reproduce
//! exactly. The #1612/#1614 boundary row IS in the matching set, so the renorm
//! decision it pins now matches torch. This is a strict improvement over the
//! `powf` path, not a regression — and the parity-sweep envelope (atol 1e-7)
//! tolerates the residual one-ULP differences on non-boundary rows.

/// Lane count of one torch `Vectorized<float>` on AVX2: a 256-bit register
/// holds `256 / 32 == 8` f32 values.
///
/// The oracle host is AVX2 (width-8), not AVX512 (width-16). The rounding of
/// the accumulation tree depends on this width, so the model reproduces
/// exactly the width-8 structure torch's compiled kernel uses there. See
/// `aten/src/ATen/cpu/vec/vec256/vec256_float.h`
/// (`Vectorized<float>::size() == 8`).
const F32_LANES: usize = 256 / (8 * std::mem::size_of::<f32>());

/// Compute the L2 (Euclidean) norm of an f32 slice the way PyTorch's CPU
/// reduction does, so the result matches `at::norm(2.0)` on an f32 contiguous
/// last-dim reduction byte-for-byte (on the AVX2 oracle host, modulo the ~3%
/// one-ULP residual documented at the module level).
///
/// This mirrors the vectorized last-dim L2 kernel at
/// `aten/src/ATen/native/cpu/ReduceOpsKernel.cpp:222-255`:
///
/// ```text
///   fVec acc_vec{acc_t(0)};                       // 8 lanes, all zero
///   acc_t buffer[fVec::size()];                   // [f32; 8]
///   for (; d < size - (size % 8); d += 8) {
///     acc_vec += data_vec * data_vec;             // lane-wise mul+add
///   }
///   acc_vec.store(buffer);
///   for (j = 1; j < 8; j++) buffer[0] += buffer[j];  // naive LEFT-FOLD
///   for (; d < size; d++) buffer[0] += data*data;    // scalar tail
///   result = sqrt(buffer[0]);
/// ```
///
/// The accumulator type is `at::opmath_type<float> == float` (`OpMathType.h:16`),
/// i.e. f32, NOT f64. The lane accumulate uses plain multiply-then-add (AVX2
/// `_mm256_mul_ps` + `_mm256_add_ps`, `vec256_float.h:564`, NOT fused). The
/// scalar tail's `buffer[0] += data*data` compiles to a fused multiply-add, so
/// we use [`f32::mul_add`] there. `NormTwoOps::project` is `device_sqrt`
/// (`SharedReduceOps.h:375-381`), i.e. `f32::sqrt`.
///
/// A scalar model of the tree (no `unsafe` SIMD intrinsics) is used for
/// portability and determinism. f32 `a + b*b` without contraction is
/// bit-identical to AVX2 `_mm256_add_ps(a, _mm256_mul_ps(b, b))` lane-wise, so
/// the scalar lane model reproduces the AVX2 path's rounding exactly.
///
/// The slice is walked with `chunks_exact(8)`. Its full chunks are exactly the
/// kernel's `size - (size % 8)` main-loop prefix, and its `remainder()` is
/// exactly the scalar tail. This keeps the split structural and avoids
/// per-element bounds checks without reordering any floating-point operation.
///
/// An empty slice yields `0.0` (sum of zero squares, `sqrt(0) == 0`).
#[must_use]
pub fn l2_norm_f32_torch(data: &[f32]) -> f32 {
    // Same split as the kernel's `d < size - (size % Vec::size())` bound:
    // full 8-wide chunks for the main loop, the `size % 8` leftovers as the tail.
    let chunks = data.chunks_exact(F32_LANES);
    let tail = chunks.remainder();

    // 8 lane accumulators, all zero — mirrors `fVec acc_vec{acc_t(0)}`.
    let mut lanes = [0.0_f32; F32_LANES];
    for chunk in chunks {
        for (lane, &x) in lanes.iter_mut().zip(chunk) {
            // `+= x*x` is a rounded f32 multiply followed by a rounded f32 add.
            // Rust never contracts this into an FMA, so it is bit-identical to
            // AVX2 `_mm256_add_ps(acc, _mm256_mul_ps(x, x))`.
            *lane += x * x;
        }
    }

    // Horizontal reduce: naive LEFT-FOLD of the 8 lanes into lane 0, exactly as
    // the kernel does (`for (j = 1; j < fVec::size(); j++) buffer[0] += buffer[j]`).
    // FP addition is non-associative, so the left-fold order is load-bearing.
    // This is NOT the AVX2 `vec_reduce_all` permute tree; that tree is used by
    // sum/prod's `binary_kernel_reduce_vec`, not by the norm last-dim kernel.
    let mut acc = lanes[0];
    for &lane in &lanes[1..] {
        acc += lane;
    }

    // Scalar remainder tail accumulates into lane 0 (`buffer[0] += data_val *
    // data_val`). The compiled statement contracts to a fused multiply-add
    // (single rounding) under PyTorch's `-ffp-contract`, so `mul_add` matches it.
    for &x in tail {
        acc = x.mul_add(x, acc);
    }

    // `NormTwoOps::project` is `device_sqrt(a)` (SharedReduceOps.h:375-381),
    // i.e. f32 sqrt of the accumulated sum-of-squares.
    acc.sqrt()
}

#[cfg(test)]
mod tests {
    use super::*;

    /// The #1612 / #1614 boundary row. It is shared by the bit-exactness test
    /// and the renorm-decision test so the two can never drift apart.
    const BOUNDARY_ROW_1614: [f32; 6] = [
        3.6006885_f32,
        18.799816,
        0.4159323,
        -2.6984854,
        -4.786058,
        25.550726,
    ];

    /// Live torch 2.11 `at::norm(2.0)` f32 bits for [`BOUNDARY_ROW_1614`]
    /// (== 32.39751052856445).
    const BOUNDARY_TORCH_BITS: u32 = 0x4201_970d;

    /// One ULP above [`BOUNDARY_TORCH_BITS`]: what a scalar `Σ v*v` (or the
    /// old `Σ powf(|v|,2)`) gives for [`BOUNDARY_ROW_1614`].
    const BOUNDARY_NAIVE_BITS: u32 = 0x4201_970e;

    /// Helper: assert that `l2_norm_f32_torch(data)` produces EXACTLY the f32
    /// bit pattern `torch_bits` that live torch 2.11 `at::norm(2.0)` produced
    /// for the same input (R-CHAR-3: `torch_bits` is the live-oracle value,
    /// not copied from ferrotorch).
    #[track_caller]
    fn assert_torch_bits(data: &[f32], torch_bits: u32) {
        let got = l2_norm_f32_torch(data);
        assert_eq!(
            got.to_bits(),
            torch_bits,
            "l2_norm_f32_torch({data:?}) = {got} (bits {:#010x}); \
             live torch at::norm(2.0) f32 = {} (bits {torch_bits:#010x})",
            got.to_bits(),
            f32::from_bits(torch_bits)
        );
    }

    /// The #1612 / #1614 boundary row. Live torch `at::norm(2.0)` f32 produces
    /// bits `0x4201970d` (== 32.39751052856445). A scalar `Σ v*v` (or the old
    /// `Σ powf(|v|,2)`) gives `0x4201970e`, one ULP high, which flips the
    /// `max_norm` renorm decision. This primitive must reproduce `0x4201970d`.
    /// Oracle: live torch 2.11.0+cu130, 2026-05-28.
    #[test]
    fn matches_torch_boundary_row_1614() {
        assert_torch_bits(&BOUNDARY_ROW_1614, BOUNDARY_TORCH_BITS);
    }

    /// Length-8 row (exactly one full 8-wide lane chunk, empty tail). Oracle:
    /// live torch 2.11 `at::norm(2.0)` f32 = bits `0x42547be4`.
    #[test]
    fn matches_torch_len8() {
        let row = [
            8.36561_f32,
            -28.49935,
            -13.49824,
            -16.60736,
            14.18827,
            10.60197,
            23.53077,
            -24.78367,
        ];
        assert_torch_bits(&row, 0x4254_7be4);
    }

    /// Length-6 random row (pure scalar tail, no full lane chunk). Oracle: live
    /// torch 2.11 `at::norm(2.0)` f32 = bits `0x423abbaf`.
    #[test]
    fn matches_torch_len6_remainder() {
        let row = [-24.30948_f32, 23.0681, 5.86093, 0.7079, -25.30966, 19.51437];
        assert_torch_bits(&row, 0x423a_bbaf);
    }

    /// Length-7 random row (7-element scalar tail). Oracle: live torch 2.11
    /// `at::norm(2.0)` f32 = bits `0x4258780e`.
    #[test]
    fn matches_torch_len7_remainder() {
        let row = [
            -1.91376_f32,
            23.78282,
            -27.81234,
            0.16903,
            22.27274,
            24.92021,
            21.6505,
        ];
        assert_torch_bits(&row, 0x4258_780e);
    }

    /// Length-13 row (one 8-wide chunk + a 5-element tail). Oracle: live torch
    /// 2.11 `at::norm(2.0)` f32 = bits `0x425c3351`.
    #[test]
    fn matches_torch_len13() {
        let row = [
            -2.20465_f32,
            28.78827,
            11.0945,
            -3.93455,
            -4.74478,
            22.8556,
            -1.42997,
            -19.06973,
            -13.36156,
            -22.74215,
            5.86617,
            -19.91655,
            4.5731,
        ];
        assert_torch_bits(&row, 0x425c_3351);
    }

    /// Length-16 row (two full 8-wide chunks, empty tail). Oracle: live torch
    /// 2.11 `at::norm(2.0)` f32 = bits `0x42a3f203`.
    #[test]
    fn matches_torch_len16() {
        let row = [
            -7.07259_f32,
            -28.7053,
            17.81876,
            -24.27542,
            27.74416,
            14.62365,
            -12.6029,
            29.56241,
            6.84485,
            5.97385,
            -17.96887,
            -20.27365,
            19.81938,
            -22.51791,
            28.62067,
            -19.66962,
        ];
        assert_torch_bits(&row, 0x42a3_f203);
    }

    /// Length-17 row (two 8-wide chunks + a 1-element tail). Oracle: live torch
    /// 2.11 `at::norm(2.0)` f32 = bits `0x429fcbe7`. This particular 17-row IS
    /// byte-exact under the model; not every odd-tail length lands off.
    #[test]
    fn matches_torch_len17() {
        let row = [
            8.04938_f32,
            -29.45572,
            26.4767,
            -1.62244,
            27.07591,
            -16.99281,
            8.93223,
            -3.06544,
            13.78239,
            27.94275,
            18.62917,
            -22.4124,
            -23.7873,
            11.43891,
            -16.49435,
            -28.32381,
            6.74599,
        ];
        assert_torch_bits(&row, 0x429f_cbe7);
    }

    /// A KNOWN-RESIDUAL row (R-HONEST-3): the portable model lands ONE ULP
    /// below live torch here. This documents the ~3% residual honestly rather
    /// than hiding it. Live torch 2.11 `at::norm(2.0)` f32 = bits `0x40c9a36f`.
    /// The model gives `0x40c9a36e` (one ULP low) because torch's compiled
    /// scalar remainder contracts FMA in a pattern this length-5 tail can't
    /// reproduce. The #1612 / #1614 boundary cases are NOT in this residual
    /// set. The model value `0x40c9a36e` is pinned so that a future algorithm
    /// change that moves this row is caught. The live-torch value is recorded
    /// alongside so the divergence direction stays auditable.
    #[test]
    fn known_residual_one_ulp_below_torch() {
        const TORCH_BITS: u32 = 0x40c9_a36f;
        let row = [0.60962_f32, 2.0169, -3.36223, 4.05906, 2.73588];
        let got = l2_norm_f32_torch(&row);
        assert_eq!(
            got.to_bits(),
            0x40c9_a36e,
            "residual len-5 row: model bits should be the known 0x40c9a36e \
             (one ULP below live torch {TORCH_BITS:#010x}); if this changes, \
             re-derive the model against the oracle"
        );
        assert!(
            (i64::from(got.to_bits()) - i64::from(TORCH_BITS)).abs() <= 1,
            "residual must stay within one ULP of live torch"
        );
    }

    /// Single-element row: the norm of `[x]` is `|x|`. For `x = 29.04990959`,
    /// live torch `at::norm(2.0)` f32 = bits `0x41e86637`.
    #[test]
    #[allow(
        clippy::excessive_precision,
        reason = "29.04990959 is the verbatim torch input scalar for the len-1 \
                  norm oracle; kept for provenance against at::norm(2.0) f32"
    )]
    fn matches_torch_len1() {
        let row = [29.04990959_f32];
        assert_torch_bits(&row, 0x41e8_6637);
    }

    /// Empty slice norms to 0 (sum of zero squares, sqrt(0) == 0).
    #[test]
    fn empty_is_zero() {
        assert_eq!(l2_norm_f32_torch(&[]).to_bits(), 0.0_f32.to_bits());
    }

    /// With `max_norm` set to live torch's own f32 norm for the boundary row,
    /// the model's norm must reproduce torch's "not greater, do not clip"
    /// decision. This is the exact comparison `embedding_renorm_cpu_` makes
    /// (`Embedding.cpp:204`, `norm > max_norm`).
    #[test]
    fn boundary_decision_does_not_clip() {
        let norm = l2_norm_f32_torch(&BOUNDARY_ROW_1614);
        // Threshold: live torch's f32 norm widened to f64 (== `.item<double>()`
        // at Embedding.cpp:203). It comes from the oracle, NOT from the function
        // under test. Deriving it from `norm` itself would make `norm > max_norm`
        // false by construction, and the test could never fail.
        let max_norm = f64::from(f32::from_bits(BOUNDARY_TORCH_BITS));
        // Oracle sanity check: the one-ULP-high naive value WOULD clip against
        // this threshold. That flip is exactly the #1612 / #1614 defect.
        assert!(
            f64::from(f32::from_bits(BOUNDARY_NAIVE_BITS)) > max_norm,
            "the naive one-ULP-high norm must exceed the threshold (it would clip)"
        );
        #[allow(
            clippy::neg_cmp_op_on_partial_ord,
            reason = "deliberately mirrors torch's `norm > max_norm` decision \
                      (Embedding.cpp:204) verbatim — asserting it is false at the \
                      boundary; `<=` would obscure the upstream comparison operator"
        )]
        let does_not_clip = !(f64::from(norm) > max_norm);
        assert!(
            does_not_clip,
            "norm > max_norm must be false at the boundary (torch does not clip)"
        );
    }
}