tokitai-operator 0.1.0

Verified DL kernel compiler: formally-checked GEMM, p-adic, sheaf, contract-carrying ops. Paper-artifact grade.
Documentation
//! IEEE 754 binary16 <-> binary32 conversion helpers.
//!
//! Pure-Rust conversions used by the lib (e.g. `training_runner`
//! wiring the batched binary AdamW step) and by the HIP
//! fp16 <-> fp32 boundary. No platform intrinsics; the conversions
//! are bit-level and round-trip safe.
//!
// IEEE 754 binary16 <-> binary32 conversion helpers shared between
// the lib (e.g. `training_runner` wiring the batched binary AdamW
// into the per-parameter step loop) and the `hip_adamw` test
// suite. The HIP kernel takes fp16 weights/gradients and fp32
// moments; the model layer stores everything as fp32, so every
// wire-format round-trip needs a host-side conversion.
//
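// The conversion is written to match HIP's `__half2float` (decode) and
// `__float2half_rn` (encode, round-to-nearest-even) bit-exactly, so a
// finite or infinite value uploaded to the GPU and read back to the host
// is byte-identical to the original. NaNs stay NaNs across the round
// trip, but their payload bits are not guaranteed to survive.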

/// IEEE 754 binary16 -> binary32 conversion (matches HIP
/// `__half2float`).
///
/// Widening is exact: every binary16 value, subnormals included, has an
/// exact binary32 representation, so no rounding takes place.
///
/// * `±0` and `±infinity` keep their sign.
/// * Normal numbers re-bias the exponent from 15 to 127 (a shift of 112).
/// * Subnormals (`exp_field == 0`, nonzero fraction) are renormalised into
///   binary32 normal numbers.
/// * NaNs keep their sign; their 10-bit fraction is copied into the top of
///   the binary32 significand and the lowest significand bit is also set.
pub fn f16_to_f32(bits: u16) -> f32 {
    let sign = u32::from(bits >> 15) << 31;
    let exp_field = u32::from((bits >> 10) & 0x1f);
    let fraction = u32::from(bits & 0x3ff);

    let magnitude = match (exp_field, fraction) {
        // Signed zero: only the sign bit survives.
        (0, 0) => 0,
        // Subnormal: value = fraction * 2^-24. Move the leading one up to
        // bit 10 (the implicit-bit position) and drop it; every shift lowers
        // the binary32 exponent by one, starting from 2^-14 (biased 113).
        (0, f) => {
            let shift = f.leading_zeros() - 21;
            ((113 - shift) << 23) | (((f << shift) & 0x3ff) << 13)
        }
        // Infinity (fraction == 0) or NaN (fraction != 0).
        (0x1f, f) => (0xff << 23) | (f << 13) | u32::from(f != 0),
        // Normal number: re-bias 15 -> 127.
        (e, f) => ((e + 112) << 23) | (f << 13),
    };
    f32::from_bits(sign | magnitude)
}

/// IEEE 754 binary32 -> binary16 round-to-nearest-even conversion
/// (written to match HIP `__float2half_rn` rounding).
///
/// Behaviour by input class:
///
/// * Finite values round to nearest, ties to even. Magnitudes at or above
///   65520 (halfway past the largest finite binary16, 65504) round to
///   ±infinity. Magnitudes at or below 2^-25 (half the smallest subnormal,
///   2^-24) flush to a signed zero.
/// * `±infinity` maps to `0x7c00` / `0xfc00`.
/// * NaN maps to a *quiet* binary16 NaN. The quiet bit (`0x0200`) is always
///   set, the sign is kept, and the top 10 bits of the binary32 payload are
///   preserved, so `f32::NAN` encodes to `0x7e00`.
pub fn f32_to_f16(value: f32) -> u16 {
    let bits = value.to_bits();
    let sign = ((bits >> 31) as u16) << 15;
    let exp = ((bits >> 23) & 0xff) as i32;
    let frac = bits & 0x7f_ffff;

    if exp == 0xff {
        if frac == 0 {
            return sign | 0x7c00;
        }
        // NaN. The encoder used to return `0x7c01`, which has the binary16
        // quiet bit clear, i.e. a *signaling* NaN. A format conversion must
        // deliver a quiet NaN, so force the quiet bit and keep the payload's
        // top bits (this also makes quiet fp16 NaNs survive a round trip).
        return sign | 0x7e00 | (frac >> 13) as u16;
    }

    // Re-bias the exponent from binary32 (bias 127) to binary16 (bias 15).
    let new_exp = exp - 127 + 15;
    if new_exp >= 0x1f {
        return sign | 0x7c00;
    }
    if new_exp <= 0 {
        if new_exp < -10 {
            // |value| < 2^-25: below half the smallest subnormal.
            return sign;
        }
        // Subnormal fp16: the 24-bit significand (implicit bit restored)
        // shifts right by (14 - new_exp) into the 10-bit subnormal field.
        // The shift is 14 for new_exp = 0 and 24 for new_exp = -10.
        let mantissa = frac | 0x80_0000;
        let rounded = shift_right_rne(mantissa, (14 - new_exp) as u32);
        // `rounded` is at most 0x400, which is only reachable when
        // new_exp == 0. 0x400 is exactly the encoding of the smallest normal
        // fp16 (exp_field = 1, fraction = 0), so the carry needs no special
        // case. An earlier version mis-encoded this boundary as 0x4000 (2.0).
        return sign | rounded as u16;
    }
    // Normal fp16: keep the top 10 fraction bits, rounding away 13.
    let rounded = shift_right_rne(frac, 13);
    // A rounding carry (rounded == 0x400) bumps the exponent field by one.
    // From new_exp == 30 that lands on 0x7c00, i.e. overflow to infinity.
    sign | (((new_exp as u32) << 10) + rounded) as u16
}

/// Shifts `value` right by `shift` bits (`1..=31`), rounding the discarded
/// bits to nearest with ties to even.
fn shift_right_rne(value: u32, shift: u32) -> u32 {
    debug_assert!((1..32).contains(&shift));
    let half = 1u32 << (shift - 1);
    let discarded = value & ((1u32 << shift) - 1);
    let truncated = value >> shift;
    let round_up = discarded > half || (discarded == half && (truncated & 1) == 1);
    truncated + u32::from(round_up)
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `value` to fp16 and decodes it back to fp32.
    fn roundtrip(value: f32) -> f32 {
        f16_to_f32(f32_to_f16(value))
    }

    #[test]
    fn roundtrip_preserves_zero_and_one() {
        assert_eq!(roundtrip(0.0f32).to_bits(), 0.0f32.to_bits());
        assert_eq!(roundtrip(-0.0f32).to_bits(), (-0.0f32).to_bits());
        assert_eq!(roundtrip(1.0f32).to_bits(), 1.0f32.to_bits());
        assert_eq!(roundtrip(-1.0f32).to_bits(), (-1.0f32).to_bits());
    }

    #[test]
    fn roundtrip_finite_values_within_one_ulp() {
        // A handful of ordinary normal values (fractions, integers and the
        // largest finite fp16 magnitude). The round-trip error must stay
        // within fp16's relative half-ULP (2^-11 ≈ 4.9e-4) plus a small
        // absolute slack. (`-3.3` rather than `-3.14`, which clippy's
        // deny-by-default `approx_constant` lint flags as an approximate PI.)
        for v in [
            0.1f32, 0.5, 0.99999, 1.0e-3, 1.0e3, -3.3, 65504.0, -65504.0,
        ] {
            let rt = roundtrip(v);
            let ulp = (v - rt).abs();
            assert!(
                ulp <= v.abs() * 5e-4 + 1e-7,
                "roundtrip of {v} gave {rt}, ulp={ulp}"
            );
        }
    }

    #[test]
    fn finite_and_infinite_f16_roundtrip_bit_exactly() {
        // Every non-NaN fp16 pattern is exactly representable in fp32, so
        // decode -> encode must reproduce the original bits.
        for h in 0..=u16::MAX {
            let is_nan = (h & 0x7c00) == 0x7c00 && (h & 0x03ff) != 0;
            if is_nan {
                continue;
            }
            assert_eq!(f32_to_f16(f16_to_f32(h)), h, "pattern {h:#06x}");
        }
    }

    #[test]
    fn non_finite_values_keep_their_class() {
        assert_eq!(f32_to_f16(f32::INFINITY), 0x7c00);
        assert_eq!(f32_to_f16(f32::NEG_INFINITY), 0xfc00);
        assert!(f16_to_f32(0x7c00).is_infinite());
        assert!(f16_to_f32(0xfc00).is_sign_negative());
        assert!(f16_to_f32(0x7e00).is_nan());
        let h = f32_to_f16(f32::NAN);
        assert_eq!(h & 0x7c00, 0x7c00, "NaN must keep an all-ones exponent");
        assert_ne!(h & 0x03ff, 0, "NaN must keep a nonzero fraction");
        assert!(f16_to_f32(h).is_nan());
    }

    #[test]
    fn f16_to_f32_decode_matches_known_patterns() {
        // 1.0 in fp16 is sign=0, exp=15, frac=0 -> 0x3C00
        assert_eq!(f16_to_f32(0x3C00).to_bits(), 1.0f32.to_bits());
        // 2.0 in fp16 is sign=0, exp=16, frac=0 -> 0x4000
        assert_eq!(f16_to_f32(0x4000).to_bits(), 2.0f32.to_bits());
        // -1.0 is 0xBC00
        assert_eq!(f16_to_f32(0xBC00).to_bits(), (-1.0f32).to_bits());
        // 0.5 is 0x3800
        assert_eq!(f16_to_f32(0x3800).to_bits(), 0.5f32.to_bits());
    }

    #[test]
    fn f32_to_f16_encode_matches_known_patterns() {
        assert_eq!(f32_to_f16(1.0), 0x3C00);
        assert_eq!(f32_to_f16(2.0), 0x4000);
        assert_eq!(f32_to_f16(-1.0), 0xBC00);
        assert_eq!(f32_to_f16(0.5), 0x3800);
    }

    #[test]
    fn f32_to_f16_subnormal_encode_matches_known_patterns() {
        // Regression for the subnormal-shift bug. Before the fix, the
        // encoder used `(frac|0x800000) >> (1-new_exp)` and then
        // `>> (13-new_exp)`, a combined shift of `14 - 2*new_exp` instead
        // of `14 - new_exp`. Every input below 2^-15 (new_exp < 0) was
        // therefore shifted right by an extra `-new_exp` bits. For
        // example, `2^-23` (new_exp=-8) encoded as 0 instead of `0x0002`,
        // so small weights/gradients in the fp16 subnormal range were
        // silently flushed toward zero.
        //
        // Reference values (IEEE 754 binary16, RNE):
        //   2^-24 -> 0x0001  (smallest subnormal)
        //   2^-23 -> 0x0002
        //   2^-22 -> 0x0004
        //   2^-16 -> 0x0100
        //   2^-15 -> 0x0200  (top subnormal binade)
        //   2^-14 -> 0x0400  (smallest normal)
        //   2^-10 -> 0x1400  (normal; exp_field=5)
        let p2 = |e: i32| -> f32 { f32::from_bits(((127 - e) as u32) << 23) };
        assert_eq!(f32_to_f16(p2(24)), 0x0001, "2^-24 (smallest subnormal)");
        assert_eq!(f32_to_f16(p2(23)), 0x0002, "2^-23");
        assert_eq!(f32_to_f16(p2(22)), 0x0004, "2^-22");
        assert_eq!(f32_to_f16(p2(16)), 0x0100, "2^-16");
        assert_eq!(f32_to_f16(p2(15)), 0x0200, "2^-15 (largest subnormal)");
        assert_eq!(f32_to_f16(p2(14)), 0x0400, "2^-14 (smallest normal)");
        assert_eq!(f32_to_f16(p2(10)), 0x1400, "2^-10 (normal)");
        // Negative subnormals preserve sign.
        assert_eq!(f32_to_f16(-p2(23)), 0x8002, "-2^-23");
    }

    #[test]
    fn f32_to_f16_subnormal_rounds_to_nearest_even() {
        // Midpoint between 2^-24 and 2^-23: value = 1.5 * 2^-24.
        // Round-to-nearest-even picks the EVEN mantissa, i.e.
        // 0x0002 (mantissa 2) rather than 0x0001 (mantissa 1),
        // the same RNE semantics as the encoder's normal path.
        let v = 1.5f32 * 2f32.powi(-24);
        assert_eq!(f32_to_f16(v), 0x0002, "1.5 * 2^-24 -> 0x0002 (RNE)");
    }

    #[test]
    fn f32_to_f16_subnormal_rounds_up_to_smallest_normal_not_to_two() {
        // Regression: a previous bug returned (16u16 << 10) = 0x4000
        // (encoding for value 2.0) instead of (1u16 << 10) = 0x0400
        // (encoding for smallest normal fp16, 2^-14) when an fp32
        // value just below the subnormal/normal boundary rounded up.
        //
        // Concrete repro: cpu_gemm_bw_grad_a's reduction at
        // (row=58, col=69) of a 1024x1024 random fp16 input lands
        // at -0.0000610128 (fp32). That value has fp32 exp_field=112
        // (it lies in [2^-15, 2^-14)) and sits about 2.2e-8 below
        // 2^-14. That is less than half an fp16 subnormal ULP
        // (2^-25 ≈ 3.0e-8), so RNE after the 14-bit shift rounds the
        // mantissa up to 0x400, the smallest-normal boundary. The
        // correct fp16 is 0x8400 (-2^-14 ≈ -0.0000610352); the bug
        // produced 0xc000 (-2.0), masking itself as a 2.0 GPU error.
        let v = -0.0000610128f32;
        assert_eq!(
            f32_to_f16(v),
            0x8400,
            "fp32 value just below -2^-14 should round to smallest negative normal fp16 (0x8400), not -2.0 (0xc000)"
        );
        // The positive side too: just under +2^-14 should also
        // round up to +smallest_normal, not +2.0.
        let vp = 0.0000610128f32;
        assert_eq!(
            f32_to_f16(vp),
            0x0400,
            "fp32 value just below +2^-14 should round to smallest positive normal fp16 (0x0400), not 2.0 (0x4000)"
        );
    }

    #[test]
    fn f32_to_f16_boundary_exhaustive() {
        // Exhaustive sweep of every fp32 value with `exp_field=112`, the
        // fp32 binade [2^-15, 2^-14) that maps onto fp16's top subnormal
        // binade. For each of the 2^24 values (frac in `[0, 0x7fffff]`,
        // both signs), verify the returned fp16 has `exp_field` in
        // `{0, 1}`, i.e. either a subnormal (exp_field=0) or the smallest
        // normal (exp_field=1). It must NEVER be 16, which would encode
        // value 2.0 (the 0x4000 vs 0x0400 bug from Task #75).
        //
        // The previous encoder returned `(16u16 << 10)` when an fp32 value
        // just below the boundary rounded up across it to the
        // smallest-normal mantissa 0x400, instead of `(1u16 << 10)`. That
        // off-by-15 in the exponent field silently turned `±2^-14`
        // reductions into `±2.0` and masked itself as a 2.0 GPU error.
        let mut first_violation: Option<(u32, u16)> = None;
        'outer: for frac in 0u32..=0x7fffffu32 {
            let pos_bits: u32 = (112u32 << 23) | frac;
            let neg_bits: u32 = pos_bits | (1u32 << 31);
            for bits in [pos_bits, neg_bits] {
                let v = f32::from_bits(bits);
                let h = f32_to_f16(v);
                let exp_field = h & 0x7c00;
                if exp_field > 0x0400 {
                    first_violation = Some((bits, h));
                    break 'outer;
                }
            }
        }
        if let Some((bits, h)) = first_violation {
            panic!(
                "f32_to_f16({:#010x}) = {:#06x}: exp_field={:#x} > 0x0400. \
                 For fp32 values strictly below 2^-14, the returned fp16 \
                 must have exp_field in {{0, 1}} (subnormal or smallest \
                 normal), never 16 (which encodes value 2.0 — the \
                 subnormal-boundary round-up bug from Task #75).",
                bits,
                h,
                h & 0x7c00
            );
        }
    }

    #[test]
    fn f32_to_f16_roundtrip_property_random() {
        // Property test over ~1000 random non-negative fp32 bit patterns
        // (NaN and +inf skipped), drawn from a fixed-seed LCG so the test
        // is reproducible without adding a `rand` dependency. For each
        // value v we check:
        //   1. Overflow: above 65504 (the largest finite fp16) the encoder
        //      must return 0x7bff while v < 65520 (RNE rounds down to
        //      65504) and 0x7c00 (+inf) from 65520 upward. The relative
        //      round-trip bound is meaningless there, so 2 and 3 are
        //      skipped for these values.
        //   2. Round-trip correctness: f32 -> f16 -> f32' satisfies
        //      |v - v'| <= |v| * 5e-4 + 1e-7 (the same bound as
        //      `roundtrip_finite_values_within_one_ulp`; fp16's relative
        //      half-ULP is 2^-11 ≈ 4.9e-4 and its absolute half-ULP in
        //      the subnormal range is 2^-25 ≈ 3e-8).
        //   3. Boundary property: when v has fp32 exp_field=112
        //      (i.e. v in [2^-15, 2^-14)), the fp16 exp_field is
        //      in {0, 1} — never 16.
        let mut rng_state: u64 = 0x12345678_9abcdef0_u64;
        let mut next_u32 = || -> u32 {
            // 64-bit LCG with Knuth's MMIX multiplier and increment. Only
            // the high-order state bits are well mixed, so return the
            // top 31 bits.
            rng_state = rng_state
                .wrapping_mul(6364136223846793005_u64)
                .wrapping_add(1442695040888963407_u64);
            (rng_state >> 33) as u32
        };

        let mut checked: u32 = 0;
        let mut attempts: u32 = 0;
        while checked < 1000 && attempts < 5000 {
            attempts += 1;
            let pos_bits = next_u32() & 0x7fff_ffff_u32; // strip sign
            // Skip NaN / +inf (sign stripped already, so only +inf).
            if (pos_bits & 0x7f80_0000_u32) == 0x7f80_0000_u32 {
                continue;
            }
            let v = f32::from_bits(pos_bits);
            let h = f32_to_f16(v);
            if v > 65504.0_f32 {
                // Beyond the finite fp16 range: only the halfway interval
                // (65504, 65520) rounds back down to 65504.
                let expected: u16 = if v < 65520.0_f32 { 0x7bff } else { 0x7c00 };
                assert_eq!(
                    h, expected,
                    "v={} (bits {:#010x}) beyond the finite fp16 range encoded as {:#06x}",
                    v, pos_bits, h
                );
                checked += 1;
                continue;
            }
            let v_round = f16_to_f32(h);
            let err = (v - v_round).abs();
            assert!(
                err <= v.abs() * 5e-4 + 1e-7,
                "roundtrip of v={} (bits {:#010x}) gave v'={}, err={}",
                v,
                pos_bits,
                v_round,
                err
            );
            // For v in [2^-15, 2^-14), the fp16 exp field is 0
            // (subnormal) or 1 (smallest normal) — never 16.
            let exp = (pos_bits >> 23) & 0xff;
            if exp == 112 {
                let exp_field = h & 0x7c00;
                assert!(
                    exp_field <= 0x0400,
                    "fp32 bits {:#010x} (v={}) in [2^-15, 2^-14) must \
                     round to fp16 exp_field in {{0, 1}}, got {:#06x}",
                    pos_bits,
                    v,
                    h
                );
            }
            checked += 1;
        }
        assert!(
            checked >= 500,
            "not enough valid random samples collected: {}",
            checked
        );
    }
}