Skip to main content

entrenar/autograd/precision/
conversions.rs

1//! Precision conversion functions and memory estimation utilities.
2
3use super::Precision;
4
/// Convert f32 to bf16 by truncation (round toward zero).
///
/// BF16 keeps the f32 sign bit, all 8 exponent bits and the top 7 mantissa
/// bits, so the result is the upper 16 bits of the f32 encoding. The lower
/// 16 mantissa bits are discarded, not rounded.
///
/// NaN handling: a NaN whose payload lives entirely in the discarded low bits
/// would truncate to an all-zero mantissa, which encodes ±Inf. In that case
/// the quiet bit is set so the result is still a NaN. All other inputs are
/// pure truncation.
pub fn f32_to_bf16(value: f32) -> u16 {
    let upper = (value.to_bits() >> 16) as u16;
    // 0x007F = bf16 mantissa field; 0x0040 = its most significant (quiet) bit.
    if value.is_nan() && upper & 0x007F == 0 {
        upper | 0x0040
    } else {
        upper
    }
}
13
/// Convert a bf16 bit pattern to f32.
///
/// The conversion is exact: the 16 bf16 bits become the high half of the
/// f32 encoding and the low 16 mantissa bits are filled with zeros.
pub fn bf16_to_f32(value: u16) -> f32 {
    let widened: u32 = value.into();
    f32::from_bits(widened << 16)
}
20
/// Convert f32 to fp16 (IEEE half precision)
///
/// ONE PATH: Delegates to `trueno::f32_to_f16` (UCBD §4).
///
/// Returns the raw 16-bit half-precision encoding as a `u16`. Rounding,
/// overflow-to-infinity and NaN handling are entirely whatever
/// `trueno::f32_to_f16` implements; this wrapper adds no logic of its own.
/// NOTE(review): confirm trueno rounds to nearest-even rather than truncating.
pub fn f32_to_fp16(value: f32) -> u16 {
    trueno::f32_to_f16(value)
}
27
/// Convert fp16 to f32
///
/// ONE PATH: Delegates to `trueno::f16_to_f32` (UCBD §4).
///
/// `value` is the raw 16-bit half-precision encoding. Handling of subnormals,
/// infinities and NaNs is whatever `trueno::f16_to_f32` implements; this
/// wrapper adds no logic of its own.
pub fn fp16_to_f32(value: u16) -> f32 {
    trueno::f16_to_f32(value)
}
34
/// Truncate an f32 value to BF16 precision (zero lower 16 mantissa bits).
///
/// Equivalent to an f32 → bf16 → f32 round-trip via bit truncation (not
/// rounding). The result is a valid f32 with only 7 mantissa bits of
/// precision.
///
/// # Contract (C-BF16GEMM-001)
///
/// - `bf16_truncate(x).to_bits() & 0x0000FFFF == 0` for all x
/// - `bf16_truncate(NaN).is_nan()` and `bf16_truncate(Inf).is_infinite()`
/// - `bf16_truncate(x) == bf16_to_f32(f32_to_bf16(x))` for all non-NaN x
///
/// NaN handling: a NaN whose payload lives only in the low 16 bits would
/// truncate to an all-zero mantissa, which encodes ±Inf. To uphold the NaN
/// clause above, the quiet bit (0x0040_0000) is set in that case. It lies in
/// the kept upper half, so the low-bits-zero clause still holds.
#[inline]
pub fn bf16_truncate(val: f32) -> f32 {
    let bits = val.to_bits() & 0xFFFF_0000;
    // 0x007F_0000 = the 7 mantissa bits that survive truncation.
    if val.is_nan() && bits & 0x007F_0000 == 0 {
        f32::from_bits(bits | 0x0040_0000)
    } else {
        f32::from_bits(bits)
    }
}
49
50/// CPU reference implementation of BF16-precision GEMM.
51///
52/// Computes C = A @ B where A is MxK, B is KxN, but truncates each operand
53/// to BF16 precision before multiply, with FP32 accumulation.
54///
55/// This matches the precision characteristics of hardware BF16 tensor cores:
56/// - BF16 multiply (7-bit mantissa)
57/// - FP32 accumulation (23-bit mantissa)
58///
59/// Used for verification against GPU BF16 GEMM kernels.
60pub fn gemm_bf16_reference(a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
61    let mut c = vec![0.0f32; m * n];
62    for row in 0..m {
63        for col in 0..n {
64            let mut acc = 0.0f32;
65            for i in 0..k {
66                let a_val = bf16_truncate(a[row * k + i]);
67                let b_val = bf16_truncate(b[i * n + col]);
68                acc = a_val.mul_add(b_val, acc);
69            }
70            c[row * n + col] = acc;
71        }
72    }
73    c
74}
75
76/// Estimate memory savings from mixed precision
77///
78/// # Arguments
79///
80/// * `num_params` - Number of model parameters
81/// * `batch_size` - Batch size
82/// * `seq_len` - Sequence length
83/// * `hidden_size` - Hidden dimension
84/// * `precision` - Target precision
85///
86/// # Returns
87///
88/// Tuple of (fp32_bytes, mixed_bytes, savings_ratio)
89pub fn estimate_memory_savings(
90    num_params: usize,
91    batch_size: usize,
92    seq_len: usize,
93    hidden_size: usize,
94    precision: Precision,
95) -> (usize, usize, f32) {
96    // FP32 memory: params + activations + gradients
97    let param_bytes_fp32 = num_params * 4;
98    let activation_bytes_fp32 = batch_size * seq_len * hidden_size * 4;
99    let grad_bytes_fp32 = num_params * 4;
100    let total_fp32 = param_bytes_fp32 + activation_bytes_fp32 + grad_bytes_fp32;
101
102    // Mixed precision: master weights (fp32) + activations (reduced) + gradients (reduced)
103    let param_bytes_mixed = num_params * 4; // Master weights in fp32
104    let activation_bytes_mixed = batch_size * seq_len * hidden_size * precision.size_bytes();
105    let grad_bytes_mixed = num_params * precision.size_bytes();
106    let total_mixed = param_bytes_mixed + activation_bytes_mixed + grad_bytes_mixed;
107
108    let savings = 1.0 - (total_mixed as f32 / total_fp32 as f32);
109    (total_fp32, total_mixed, savings)
110}