Skip to main content

cuda_rust_wasm/runtime/
bfloat16.rs

1//! BFloat16 (bf16) floating-point support
2//!
3//! Implements the Google Brain bfloat16 format used extensively in ML training.
4//! BF16 has the same exponent range as f32 (8 bits) but reduced mantissa (7 bits),
5//! making it ideal for training where range matters more than precision.
6//!
7//! Layout: 1 sign bit, 8 exponent bits, 7 mantissa bits.
8//! Range: same as f32 (±3.4×10³⁸), precision: ~2 decimal digits.
9
10use std::fmt;
11use std::ops::{Add, Sub, Mul, Div, Neg};
12
/// BFloat16 — Google Brain's 16-bit floating-point format.
///
/// Unlike IEEE fp16 (Half), bf16 shares f32's exponent range, making it
/// a drop-in replacement for f32 in training loops where dynamic range
/// is more important than mantissa precision.
///
/// Equality and hashing are derived from the raw bit pattern: two NaNs with
/// the same encoding compare equal and `+0.0 != -0.0`. Use `partial_cmp`
/// (or compare `to_f32()` values) when numeric comparison is wanted.
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub struct BFloat16 {
    /// Raw encoding: sign (bit 15), exponent (bits 14-7), mantissa (bits 6-0).
    bits: u16,
}

impl BFloat16 {
    pub const ZERO: Self = Self { bits: 0x0000 };
    pub const ONE: Self = Self { bits: 0x3F80 };
    pub const NEG_ONE: Self = Self { bits: 0xBF80 };
    pub const INFINITY: Self = Self { bits: 0x7F80 };
    pub const NEG_INFINITY: Self = Self { bits: 0xFF80 };
    pub const NAN: Self = Self { bits: 0x7FC0 };
    pub const MAX: Self = Self { bits: 0x7F7F }; // ~3.39×10³⁸
    pub const MIN_POSITIVE: Self = Self { bits: 0x0080 }; // smallest normal
    pub const EPSILON: Self = Self { bits: 0x3C00 }; // 2^-7 ≈ 0.0078125

    /// Sign bit of the encoding.
    const SIGN_MASK: u16 = 0x8000;
    /// The eight exponent bits; all ones means infinity or NaN.
    const EXP_MASK: u16 = 0x7F80;
    /// The seven mantissa bits.
    const MANT_MASK: u16 = 0x007F;
    /// Most significant mantissa bit; set on quiet NaNs.
    const QUIET_BIT: u16 = 0x0040;

    /// Create from raw u16 bits.
    pub const fn from_bits(bits: u16) -> Self {
        Self { bits }
    }

    /// Get raw u16 bits.
    pub const fn to_bits(self) -> u16 {
        self.bits
    }

    /// Convert from f32 using round-to-nearest, ties-to-even.
    ///
    /// * NaN inputs always produce a NaN: the sign and the upper payload bits
    ///   are kept and the quiet bit is forced on, so a payload that lives only
    ///   in the discarded low 16 bits cannot turn into infinity or zero.
    /// * Finite inputs whose rounding would carry into the infinity encoding
    ///   are clamped to `±MAX` instead of becoming infinite.
    /// * Infinities and zeros are preserved exactly.
    pub fn from_f32(value: f32) -> Self {
        let bits = value.to_bits();
        // A u32 shifted right by 16 always fits in u16, so this cast is lossless.
        let upper = (bits >> 16) as u16;

        if value.is_nan() {
            // Rounding must not be applied to NaNs: the carry could wrap the
            // encoding around to ±0, and truncation alone could yield ±inf.
            return Self { bits: upper | Self::QUIET_BIT };
        }

        // Bit 15 is the round bit, bits 0-14 are the sticky bits, and bit 16
        // (the LSB of the kept half) breaks ties towards even.
        let round_bit = (bits >> 15) & 1 == 1;
        let sticky = (bits & 0x7FFF) != 0;
        let lsb_odd = (upper & 1) == 1;
        let round_up = round_bit && (sticky || lsb_odd);

        // Cannot overflow: for non-NaN inputs `upper` is at most 0xFF80, and
        // 0xFF80 (negative infinity) has a zero low half so it never rounds up.
        let rounded = upper + u16::from(round_up);

        let input_finite = (upper & Self::EXP_MASK) != Self::EXP_MASK;
        let result_non_finite = (rounded & Self::EXP_MASK) == Self::EXP_MASK;
        if input_finite && result_non_finite {
            // Rounding carried into the infinity encoding; saturate to the
            // largest finite magnitude with the original sign.
            Self { bits: (upper & Self::SIGN_MASK) | Self::MAX.bits }
        } else {
            Self { bits: rounded }
        }
    }

    /// Convert to f32 (lossless — just pad lower 16 bits with zeros).
    pub fn to_f32(self) -> f32 {
        f32::from_bits(u32::from(self.bits) << 16)
    }

    /// Check if NaN (exponent all ones, mantissa non-zero).
    pub const fn is_nan(self) -> bool {
        (self.bits & Self::EXP_MASK) == Self::EXP_MASK && (self.bits & Self::MANT_MASK) != 0
    }

    /// Check if infinite (either sign).
    pub const fn is_infinite(self) -> bool {
        (self.bits & !Self::SIGN_MASK) == Self::EXP_MASK
    }

    /// Check if finite (not NaN or infinite).
    pub const fn is_finite(self) -> bool {
        (self.bits & Self::EXP_MASK) != Self::EXP_MASK
    }

    /// Check if zero (positive or negative).
    pub const fn is_zero(self) -> bool {
        (self.bits & !Self::SIGN_MASK) == 0
    }

    /// Check if the sign bit is set.
    pub const fn is_sign_negative(self) -> bool {
        (self.bits & Self::SIGN_MASK) != 0
    }

    /// Absolute value (clears the sign bit; also applies to NaN encodings).
    pub const fn abs(self) -> Self {
        Self { bits: self.bits & !Self::SIGN_MASK }
    }

    /// Fused multiply-add: a * b + c (computed in f32).
    ///
    /// The f32 `mul_add` result is then rounded to bf16 by [`BFloat16::from_f32`].
    pub fn fma(a: BFloat16, b: BFloat16, c: BFloat16) -> BFloat16 {
        BFloat16::from_f32(a.to_f32().mul_add(b.to_f32(), c.to_f32()))
    }

    /// Square root, computed in f32 and rounded back to bf16.
    pub fn sqrt(self) -> Self {
        BFloat16::from_f32(self.to_f32().sqrt())
    }

    /// Reciprocal (1/x), computed in f32 and rounded back to bf16.
    pub fn recip(self) -> Self {
        BFloat16::from_f32(1.0 / self.to_f32())
    }

    /// Minimum of two values (NaN-propagating).
    ///
    /// Returns [`BFloat16::NAN`] if either operand is NaN; on a tie `self` is returned.
    pub fn min(self, other: Self) -> Self {
        if self.is_nan() || other.is_nan() {
            return Self::NAN;
        }
        if self.to_f32() <= other.to_f32() { self } else { other }
    }

    /// Maximum of two values (NaN-propagating).
    ///
    /// Returns [`BFloat16::NAN`] if either operand is NaN; on a tie `self` is returned.
    pub fn max(self, other: Self) -> Self {
        if self.is_nan() || other.is_nan() {
            return Self::NAN;
        }
        if self.to_f32() >= other.to_f32() { self } else { other }
    }

    /// Clamp value to [lo, hi].
    ///
    /// Implemented as `self.max(lo).min(hi)`, so a NaN in any of the three
    /// operands yields NaN.
    pub fn clamp(self, lo: Self, hi: Self) -> Self {
        self.max(lo).min(hi)
    }
}
135
136// ── Arithmetic ops ─────────────────────────────────────────────────
137
138impl Add for BFloat16 {
139    type Output = Self;
140    fn add(self, rhs: Self) -> Self {
141        BFloat16::from_f32(self.to_f32() + rhs.to_f32())
142    }
143}
144
145impl Sub for BFloat16 {
146    type Output = Self;
147    fn sub(self, rhs: Self) -> Self {
148        BFloat16::from_f32(self.to_f32() - rhs.to_f32())
149    }
150}
151
152impl Mul for BFloat16 {
153    type Output = Self;
154    fn mul(self, rhs: Self) -> Self {
155        BFloat16::from_f32(self.to_f32() * rhs.to_f32())
156    }
157}
158
159impl Div for BFloat16 {
160    type Output = Self;
161    fn div(self, rhs: Self) -> Self {
162        BFloat16::from_f32(self.to_f32() / rhs.to_f32())
163    }
164}
165
166impl Neg for BFloat16 {
167    type Output = Self;
168    fn neg(self) -> Self {
169        Self { bits: self.bits ^ 0x8000 }
170    }
171}
172
173impl PartialOrd for BFloat16 {
174    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
175        if self.is_nan() || other.is_nan() {
176            return None;
177        }
178        self.to_f32().partial_cmp(&other.to_f32())
179    }
180}
181
182impl fmt::Debug for BFloat16 {
183    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
184        write!(f, "bf16({:.4})", self.to_f32())
185    }
186}
187
188impl fmt::Display for BFloat16 {
189    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
190        write!(f, "{:.4}", self.to_f32())
191    }
192}
193
194impl From<f32> for BFloat16 {
195    fn from(v: f32) -> Self { BFloat16::from_f32(v) }
196}
197
198impl From<BFloat16> for f32 {
199    fn from(v: BFloat16) -> f32 { v.to_f32() }
200}
201
202// ── Batch operations ───────────────────────────────────────────────
203
204/// Convert an f32 slice to bf16.
205pub fn f32_to_bf16_slice(input: &[f32]) -> Vec<BFloat16> {
206    input.iter().map(|&v| BFloat16::from_f32(v)).collect()
207}
208
209/// Convert a bf16 slice to f32.
210pub fn bf16_to_f32_slice(input: &[BFloat16]) -> Vec<f32> {
211    input.iter().map(|v| v.to_f32()).collect()
212}
213
214/// Dot product of two bf16 slices, accumulated in f32.
215pub fn bf16_dot(a: &[BFloat16], b: &[BFloat16]) -> f32 {
216    a.iter().zip(b.iter()).map(|(x, y)| x.to_f32() * y.to_f32()).sum()
217}
218
219/// Matrix-vector multiply: y = A * x, with bf16 inputs and f32 accumulation.
220/// A is (rows × cols) row-major, x is (cols,), y is (rows,).
221pub fn bf16_gemv(a: &[BFloat16], x: &[BFloat16], rows: usize, cols: usize) -> Vec<f32> {
222    (0..rows).map(|r| {
223        let row_start = r * cols;
224        (0..cols).map(|c| {
225            a[row_start + c].to_f32() * x[c].to_f32()
226        }).sum()
227    }).collect()
228}
229
230/// Mixed-precision GEMM: C = A * B with bf16 inputs and f32 accumulation.
231/// A is (m × k), B is (k × n), C is (m × n).
232pub fn bf16_gemm(a: &[BFloat16], b: &[BFloat16], m: usize, k: usize, n: usize) -> Vec<f32> {
233    let mut c = vec![0.0f32; m * n];
234    for i in 0..m {
235        for p in 0..k {
236            let a_val = a[i * k + p].to_f32();
237            for j in 0..n {
238                c[i * n + j] += a_val * b[p * n + j].to_f32();
239            }
240        }
241    }
242    c
243}
244
245// ── Tests ──────────────────────────────────────────────────────────
246
#[cfg(test)]
mod tests {
    use super::*;

    /// Values survive an f32 → bf16 → f32 round trip within bf16 precision.
    #[test]
    fn test_bf16_roundtrip() {
        let values = [0.0f32, 1.0, -1.0, 0.5, 100.0, -0.125, 3.14];
        for &v in &values {
            let bf = BFloat16::from_f32(v);
            let back = bf.to_f32();
            assert!((back - v).abs() < 0.05, "Roundtrip failed for {}: got {}", v, back);
        }
    }

    #[test]
    fn test_bf16_constants() {
        assert_eq!(BFloat16::ZERO.to_f32(), 0.0);
        assert_eq!(BFloat16::ONE.to_f32(), 1.0);
        assert_eq!(BFloat16::NEG_ONE.to_f32(), -1.0);
        assert!(BFloat16::INFINITY.is_infinite());
        assert!(BFloat16::NAN.is_nan());
        assert!(BFloat16::MAX.to_f32() > 1e38);
    }

    /// Every check bounds the absolute error; a one-sided `x - expected < tol`
    /// would also accept results that are far too small.
    #[test]
    fn test_bf16_arithmetic() {
        let a = BFloat16::from_f32(2.0);
        let b = BFloat16::from_f32(3.0);
        assert!(((a + b).to_f32() - 5.0).abs() < 0.1);
        assert!(((a - b).to_f32() - (-1.0)).abs() < 0.1);
        assert!(((a * b).to_f32() - 6.0).abs() < 0.1);
        assert!(((a / b).to_f32() - 0.6667).abs() < 0.02);
    }

    #[test]
    fn test_bf16_neg() {
        let a = BFloat16::from_f32(42.0);
        assert!((-a).to_f32() < 0.0);
        assert!(((-a).to_f32() + 42.0).abs() < 0.5);
    }

    #[test]
    fn test_bf16_comparison() {
        let a = BFloat16::from_f32(1.0);
        let b = BFloat16::from_f32(2.0);
        assert!(a < b);
        assert!(b > a);
        assert!(BFloat16::NAN.partial_cmp(&a).is_none());
    }

    #[test]
    fn test_bf16_special_values() {
        assert!(BFloat16::NAN.is_nan());
        assert!(!BFloat16::NAN.is_finite());
        assert!(BFloat16::INFINITY.is_infinite());
        assert!(!BFloat16::INFINITY.is_finite());
        assert!(BFloat16::ZERO.is_zero());
        assert!(BFloat16::from_bits(0x8000).is_zero()); // -0
    }

    #[test]
    fn test_bf16_fma() {
        let a = BFloat16::from_f32(2.0);
        let b = BFloat16::from_f32(3.0);
        let c = BFloat16::from_f32(1.0);
        let result = BFloat16::fma(a, b, c);
        assert!((result.to_f32() - 7.0).abs() < 0.1);
    }

    #[test]
    fn test_bf16_sqrt() {
        let a = BFloat16::from_f32(4.0);
        assert!((a.sqrt().to_f32() - 2.0).abs() < 0.05);
    }

    #[test]
    fn test_bf16_clamp() {
        let lo = BFloat16::from_f32(0.0);
        let hi = BFloat16::from_f32(1.0);
        let v = BFloat16::from_f32(1.5);
        assert!((v.clamp(lo, hi).to_f32() - 1.0).abs() < 0.01);
        let v2 = BFloat16::from_f32(-0.5);
        assert!((v2.clamp(lo, hi).to_f32()).abs() < 0.01);
    }

    #[test]
    fn test_bf16_batch_convert() {
        let f32s = vec![1.0f32, 2.0, 3.0, 4.0];
        let bf16s = f32_to_bf16_slice(&f32s);
        let back = bf16_to_f32_slice(&bf16s);
        assert_eq!(back.len(), f32s.len());
        for (got, want) in back.iter().zip(&f32s) {
            assert!((got - want).abs() < 0.05);
        }
    }

    #[test]
    fn test_bf16_dot() {
        let a = f32_to_bf16_slice(&[1.0, 2.0, 3.0]);
        let b = f32_to_bf16_slice(&[4.0, 5.0, 6.0]);
        let result = bf16_dot(&a, &b);
        assert!((result - 32.0).abs() < 0.5); // 1*4 + 2*5 + 3*6 = 32
    }

    #[test]
    fn test_bf16_gemv() {
        let a = f32_to_bf16_slice(&[1.0, 2.0, 3.0, 4.0]); // 2x2
        let x = f32_to_bf16_slice(&[1.0, 1.0]);
        let y = bf16_gemv(&a, &x, 2, 2);
        assert!((y[0] - 3.0).abs() < 0.1); // 1+2
        assert!((y[1] - 7.0).abs() < 0.1); // 3+4
    }

    #[test]
    fn test_bf16_gemm() {
        // 2x2 * 2x2
        let a = f32_to_bf16_slice(&[1.0, 2.0, 3.0, 4.0]);
        let b = f32_to_bf16_slice(&[5.0, 6.0, 7.0, 8.0]);
        let c = bf16_gemm(&a, &b, 2, 2, 2);
        assert!((c[0] - 19.0).abs() < 0.5); // 1*5+2*7
        assert!((c[1] - 22.0).abs() < 0.5); // 1*6+2*8
        assert!((c[2] - 43.0).abs() < 0.5); // 3*5+4*7
        assert!((c[3] - 50.0).abs() < 0.5); // 3*6+4*8
    }

    #[test]
    fn test_bf16_same_range_as_f32() {
        // bf16 should handle very large values that fp16 cannot
        let big = BFloat16::from_f32(1e30);
        assert!(big.to_f32() > 1e29);
        assert!(big.is_finite());

        let small = BFloat16::from_f32(1e-30);
        assert!(small.to_f32() > 0.0);
        assert!(small.is_finite());
    }

    #[test]
    fn test_bf16_display() {
        let v = BFloat16::from_f32(3.14);
        let s = format!("{}", v);
        assert!(s.contains("3.1"), "Expected ~3.14, got {}", s);
    }
}