//! trueno 0.17.1
//!
//! High-performance SIMD compute library with GPU support for matrix operations.
//!
//! NF4 (4-bit NormalFloat) Quantization — transpiled from bitsandbytes
//!
//! Safe Rust implementation of bitsandbytes' NF4 quantization/dequantization.
//! Algorithm source: `bitsandbytes/csrc/kernels.cu:26-153` (MIT licensed).
//! Transpiled via manual conversion from C (decy Tier 1: pure math, zero unsafe).
//!
//! # NF4 Codebook (Dettmers et al. 2023 "QLoRA" §3.1)
//!
//! The 16-entry lookup table contains quantiles of the standard normal distribution,
//! normalized to [-1.0, 1.0]. NF4 is information-theoretically optimal for normally
//! distributed weights (which transformer weights empirically are).
//!
//! # Contract: nf4-dequantization-v1
//!
//! - `nf4_codebook`: LUT has 16 entries, monotonically increasing, LUT[0]=-1, LUT[15]=1
//! - `quantize_roundtrip`: quantize(dequant(code)) == code for all 16 codes
//! - `blockwise_dequant`: x_i = NF4_LUT[nibble] × absmax[i / blocksize]

/// NF4 dequantization lookup table — 16 quantiles of the standard normal distribution.
///
/// Entry `i` is the dequantized value of 4-bit code `i`. Values are
/// Φ⁻¹((i + 0.5) / 16) rescaled so the table spans exactly [-1.0, 1.0],
/// with an exact 0.0 at code 7 (Dettmers et al. 2023, "QLoRA" §3.1).
///
/// Source: `bitsandbytes/csrc/kernels.cu:26-43`
pub const NF4_LUT: [f32; 16] = [
    -1.0,                 // code 0b0000
    -0.6961928009986877,  // code 0b0001
    -0.5250730514526367,  // code 0b0010
    -0.39491748809814453, // code 0b0011
    -0.28444138169288635, // code 0b0100
    -0.18477343022823334, // code 0b0101
    -0.09105003625154495, // code 0b0110
    0.0,                  // code 0b0111 — exact zero
    0.07958029955625534,  // code 0b1000
    0.16093020141124725,  // code 0b1001
    0.24611230194568634,  // code 0b1010
    0.33791524171829224,  // code 0b1011
    0.44070982933044434,  // code 0b1100
    0.5626170039176941,   // code 0b1101
    0.7229568362236023,   // code 0b1110
    1.0,                  // code 0b1111
];

/// Dequantize a single NF4 code to f32.
///
/// Only the low nibble of `val` participates; the high nibble is masked
/// off, so every `u8` input is a safe index into the 16-entry LUT.
///
/// Source: `bitsandbytes/csrc/kernels.cu:108`
/// `__device__ __forceinline__ float dDequantizeNF4(unsigned char val)`
#[inline]
pub fn dequantize_nf4(val: u8) -> f32 {
    let code = usize::from(val & 0x0F);
    NF4_LUT[code]
}

/// Quantize a normalized f32 value to a 4-bit NF4 code.
///
/// Input should be in [-1.0, 1.0] (pre-normalized by absmax).
/// Returns a value in [0, 15].
///
/// Source: `bitsandbytes/csrc/kernels.cu:110-153`
/// The C original is a hand-unrolled binary-search tree over 15 hardcoded
/// threshold midpoints. Here the same 15 constants are kept in a sorted
/// table and searched with `partition_point`, which performs the identical
/// strict `x > threshold` comparisons in O(log n).
#[inline]
pub fn quantize_nf4(x: f32) -> u8 {
    // MIDPOINTS[k] separates code k from code k+1: a value quantizes to
    // code k iff it is > MIDPOINTS[k-1] and <= MIDPOINTS[k]. NaN compares
    // false against every threshold and therefore maps to code 0, exactly
    // as it fell through the original decision tree.
    const MIDPOINTS: [f32; 15] = [
        -0.848_096_4,
        -0.610_632_93,
        -0.459_995_27,
        -0.339_679_43,
        -0.234_607_41,
        -0.137_911_73,
        -0.045_525_018,
        0.039_790_15,
        0.120_255_25,
        0.203_521_25,
        0.292_013_77,
        0.389_312_54,
        0.501_663_4,
        0.642_786_9,
        0.861_478_4,
    ];
    // The code equals the number of thresholds strictly below x (0..=15).
    MIDPOINTS.partition_point(|&t| x > t) as u8
}

/// Dequantize a block of NF4-packed bytes to f32.
///
/// Each input byte contains two NF4 values:
/// - High nibble (byte >> 4): first element
/// - Low nibble (byte & 0x0F): second element
///
/// `output.len()` must be `2 * packed.len()`, or `2 * packed.len() - 1`
/// when the original element count was odd (matching the ceil packing of
/// `quantize_blockwise`); in the odd case the final low nibble is padding
/// and is ignored.
///
/// # Panics
/// Panics if `blocksize == 0`, if `output.len()` is not `2 * packed.len()`
/// or `2 * packed.len() - 1`, or if `absmax` has fewer entries than the
/// number of blocks spanned by `output`.
///
/// Source: `bitsandbytes/csrc/kernels.cu:465-529` (kDequantizeBlockwise algorithm)
pub fn dequantize_blockwise(packed: &[u8], absmax: &[f32], blocksize: usize, output: &mut [f32]) {
    assert!(blocksize > 0, "blocksize must be non-zero");
    assert!(
        output.len() == packed.len() * 2 || output.len() + 1 == packed.len() * 2,
        "output must hold 2*packed or 2*packed-1 elements"
    );
    // Validate absmax coverage once up front instead of panicking mid-loop.
    let num_blocks = (output.len() + blocksize - 1) / blocksize;
    assert!(absmax.len() >= num_blocks, "absmax too small");

    for (byte_idx, &byte) in packed.iter().enumerate() {
        let elem_idx = byte_idx * 2;

        output[elem_idx] = NF4_LUT[(byte >> 4) as usize] * absmax[elem_idx / blocksize];
        // The low nibble of the final byte is padding when output length is odd.
        if let Some(slot) = output.get_mut(elem_idx + 1) {
            // Recompute the block index for the second element: with an odd
            // blocksize a byte's two elements can straddle a block boundary,
            // mirroring the per-element indexing in quantize_blockwise.
            *slot = NF4_LUT[(byte & 0x0F) as usize] * absmax[(elem_idx + 1) / blocksize];
        }
    }
}

/// Quantize a block of f32 values to NF4-packed bytes.
///
/// Each output byte contains two NF4 values packed as (high << 4) | low.
/// When `input.len()` is odd, the final byte's low nibble is zero padding.
///
/// # Panics
/// Panics if `blocksize == 0`, if `packed.len() != ceil(input.len() / 2)`,
/// or if `absmax` has fewer than `ceil(input.len() / blocksize)` entries.
///
/// Source: `bitsandbytes/csrc/kernels.cu:269-375` (kQuantizeBlockwise algorithm)
pub fn quantize_blockwise(input: &[f32], blocksize: usize, packed: &mut [u8], absmax: &mut [f32]) {
    assert!(blocksize > 0, "blocksize must be non-zero");
    assert_eq!(packed.len(), (input.len() + 1) / 2, "packed must be ceil(input/2)");
    let num_blocks = (input.len() + blocksize - 1) / blocksize;
    assert!(absmax.len() >= num_blocks, "absmax too small");

    // Pass 1: per-block maximum absolute value (the dequantization scale).
    // f32::max ignores a NaN operand, matching the original `if abs > max` scan.
    for (block, slot) in absmax.iter_mut().take(num_blocks).enumerate() {
        let start = block * blocksize;
        let end = (start + blocksize).min(input.len());
        *slot = input[start..end].iter().fold(0.0f32, |m, &v| m.max(v.abs()));
    }

    // Normalize element `i` by its own block's absmax. An all-zero block
    // keeps scale 0 and maps every element to the code for 0.0 via inv = 0.
    // (Single definition replaces the duplicated high/low inv-scale logic.)
    let normalized = |i: usize| {
        let scale = absmax[i / blocksize];
        let inv = if scale > 0.0 { 1.0 / scale } else { 0.0 };
        input[i] * inv
    };

    // Pass 2: quantize and pack two 4-bit codes per byte.
    for (byte_idx, byte) in packed.iter_mut().enumerate() {
        let elem_idx = byte_idx * 2;
        let high = quantize_nf4(normalized(elem_idx));
        // Odd-length input: pad the trailing low nibble with code 0.
        let low = if elem_idx + 1 < input.len() {
            quantize_nf4(normalized(elem_idx + 1))
        } else {
            0
        };
        *byte = (high << 4) | low;
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// FALSIFY-NF4-001: LUT is strictly monotonically increasing
    #[test]
    fn test_nf4_lut_monotonic() {
        for (i, pair) in NF4_LUT.windows(2).enumerate() {
            let (lo, hi) = (pair[0], pair[1]);
            assert!(
                lo < hi,
                "NF4_LUT[{}]={} >= NF4_LUT[{}]={}",
                i,
                lo,
                i + 1,
                hi
            );
        }
    }

    /// FALSIFY-NF4-001 (continued): boundary values
    #[test]
    fn test_nf4_lut_boundaries() {
        for (idx, expected) in [(0usize, -1.0f32), (7, 0.0), (15, 1.0)] {
            assert_eq!(NF4_LUT[idx], expected);
        }
    }

    /// FALSIFY-NF4-002: Roundtrip fidelity — quantize(dequant(code)) == code for all 16
    #[test]
    fn test_nf4_roundtrip_exhaustive() {
        for code in 0..16u8 {
            let dequantized = dequantize_nf4(code);
            let requantized = quantize_nf4(dequantized);
            assert_eq!(
                code, requantized,
                "Roundtrip failed: code={} → dequant={} → requant={}",
                code, dequantized, requantized
            );
        }
    }

    /// FALSIFY-NF4-006: Nibble unpacking order
    #[test]
    fn test_nf4_nibble_order() {
        // 0xAB splits into high nibble 0xA (10) and low nibble 0xB (11).
        let (high, low) = (0xABu8 >> 4, 0xABu8 & 0x0F);
        assert_eq!(high, 10);
        assert_eq!(low, 11);
        // The two nibbles must dequantize to distinct LUT entries.
        let (dh, dl) = (dequantize_nf4(high), dequantize_nf4(low));
        assert_ne!(dh, dl);
        assert_eq!(dh, NF4_LUT[10]);
        assert_eq!(dl, NF4_LUT[11]);
    }

    /// FALSIFY-NF4-003: Blockwise dequant/quantize roundtrip error bound
    #[test]
    fn test_nf4_blockwise_roundtrip() {
        const BLOCKSIZE: usize = 64;
        const N: usize = 256;
        let input: Vec<f32> = (0..N).map(|i| (i as f32 / N as f32) * 2.0 - 1.0).collect();

        let mut packed = vec![0u8; N / 2];
        let mut absmax = vec![0.0f32; N / BLOCKSIZE];
        quantize_blockwise(&input, BLOCKSIZE, &mut packed, &mut absmax);

        let mut output = vec![0.0f32; N];
        dequantize_blockwise(&packed, &absmax, BLOCKSIZE, &mut output);

        // Quantization error is bounded by the widest gap between adjacent
        // NF4 quantiles — codes 14 → 15: (1.0 - 0.7230) ≈ 0.277 — scaled by
        // the block's absmax.
        for (i, (&orig, &deq)) in input.iter().zip(&output).enumerate() {
            let block = i / BLOCKSIZE;
            let max_error = 0.28 * absmax[block]; // widest NF4 bin
            let error = (orig - deq).abs();
            assert!(
                error <= max_error + 1e-6,
                "Block {} elem {}: error {} > max_error {}",
                block,
                i,
                error,
                max_error
            );
        }
    }

    /// quantize_nf4 always returns [0, 15]
    #[test]
    fn test_quantize_nf4_range() {
        // 1000 evenly spaced samples covering [-1.0, 1.0).
        let samples = (0..1000).map(|i| (i as f32 / 500.0) - 1.0);
        for x in samples {
            let code = quantize_nf4(x);
            assert!(code < 16, "quantize_nf4({}) = {} >= 16", x, code);
        }
    }
}