cuda_rust_wasm/runtime/
bfloat16.rs1use std::fmt;
11use std::ops::{Add, Sub, Mul, Div, Neg};
12
/// 16-bit "brain float": 1 sign bit, 8 exponent bits, 7 mantissa bits.
///
/// The stored bit pattern is the upper 16 bits of the corresponding `f32`
/// encoding, so widening to `f32` is a plain shift and narrowing only has to
/// round away the low 16 mantissa bits.
///
/// Note: the derived `PartialEq`/`Eq`/`Hash` compare raw bit patterns. Two
/// NaNs with identical bits therefore compare equal, and `+0.0 != -0.0`.
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub struct BFloat16 {
    bits: u16,
}

impl BFloat16 {
    /// Positive zero.
    pub const ZERO: Self = Self { bits: 0x0000 };
    /// `1.0`.
    pub const ONE: Self = Self { bits: 0x3F80 };
    /// `-1.0`.
    pub const NEG_ONE: Self = Self { bits: 0xBF80 };
    /// Positive infinity.
    pub const INFINITY: Self = Self { bits: 0x7F80 };
    /// Negative infinity.
    pub const NEG_INFINITY: Self = Self { bits: 0xFF80 };
    /// A quiet NaN (positive sign, only the top mantissa bit set).
    pub const NAN: Self = Self { bits: 0x7FC0 };
    /// Largest finite value (exponent 254, all mantissa bits set).
    pub const MAX: Self = Self { bits: 0x7F7F };
    /// Smallest positive normal value (exponent 1, mantissa 0).
    pub const MIN_POSITIVE: Self = Self { bits: 0x0080 };
    /// Gap between `1.0` and the next representable value (`2^-7`).
    pub const EPSILON: Self = Self { bits: 0x3C00 };

    /// Builds a value directly from its raw 16-bit encoding.
    pub const fn from_bits(bits: u16) -> Self {
        Self { bits }
    }

    /// Returns the raw 16-bit encoding.
    pub const fn to_bits(self) -> u16 {
        self.bits
    }

    /// Converts an `f32` using round-to-nearest, ties-to-even.
    ///
    /// * NaN inputs always produce a NaN: the sign and the upper payload bits
    ///   are kept and the quiet bit is forced on.
    /// * Finite inputs whose rounding would carry into the infinity exponent
    ///   saturate to `±MAX` instead of becoming infinite.
    /// * Infinities map to the infinity of the same sign.
    pub fn from_f32(value: f32) -> Self {
        let bits = value.to_bits();
        let upper = bits >> 16;

        if value.is_nan() {
            // Plain truncation could drop the whole payload (turning the NaN
            // into an infinity) and rounding could carry into the exponent or
            // sign bit (turning it into a zero). Forcing the quiet bit keeps
            // the mantissa non-zero without touching sign or exponent.
            return Self {
                bits: (upper as u16) | 0x0040,
            };
        }

        // Bit 15 is the first discarded bit; bits 0..=14 are the sticky tail.
        let round_bit = (bits >> 15) & 1;
        let sticky = u32::from(bits & 0x7FFF != 0);
        let lsb = upper & 1;

        // Round up when above the halfway point, or exactly halfway with an
        // odd kept mantissa (ties-to-even).
        let rounded = upper + (round_bit & (sticky | lsb));

        let input_is_finite = (bits & 0x7F80_0000) != 0x7F80_0000;
        if input_is_finite && (rounded & 0x7F80) == 0x7F80 {
            // Rounding overflowed a finite value: clamp to the largest finite
            // magnitude, keeping the original sign and exponent bits.
            Self {
                bits: ((upper & 0xFF80) as u16) | 0x7F,
            }
        } else {
            Self {
                bits: rounded as u16,
            }
        }
    }

    /// Widens to `f32`; exact, since the bits are just shifted into the upper half.
    pub fn to_f32(self) -> f32 {
        f32::from_bits((self.bits as u32) << 16)
    }

    /// `true` when the exponent is all ones and the mantissa is non-zero.
    pub fn is_nan(self) -> bool {
        (self.bits & 0x7F80) == 0x7F80 && (self.bits & 0x007F) != 0
    }

    /// `true` for positive or negative infinity.
    pub fn is_infinite(self) -> bool {
        (self.bits & 0x7FFF) == 0x7F80
    }

    /// `true` when the value is neither infinite nor NaN.
    pub fn is_finite(self) -> bool {
        (self.bits & 0x7F80) != 0x7F80
    }

    /// `true` for both `+0.0` and `-0.0`.
    pub fn is_zero(self) -> bool {
        (self.bits & 0x7FFF) == 0
    }

    /// `true` when the sign bit is set (including `-0.0` and negative NaNs).
    pub fn is_sign_negative(self) -> bool {
        self.bits & 0x8000 != 0
    }

    /// Clears the sign bit.
    pub fn abs(self) -> Self {
        Self {
            bits: self.bits & 0x7FFF,
        }
    }

    /// Computes `a * b + c` via `f32::mul_add`, then rounds the result once to bf16.
    pub fn fma(a: BFloat16, b: BFloat16, c: BFloat16) -> BFloat16 {
        BFloat16::from_f32(a.to_f32().mul_add(b.to_f32(), c.to_f32()))
    }

    /// Square root, computed in `f32` and rounded back to bf16.
    pub fn sqrt(self) -> Self {
        BFloat16::from_f32(self.to_f32().sqrt())
    }

    /// Reciprocal (`1 / self`), computed in `f32` and rounded back to bf16.
    pub fn recip(self) -> Self {
        BFloat16::from_f32(1.0 / self.to_f32())
    }

    /// Smaller of the two values; returns [`BFloat16::NAN`] if either is NaN.
    ///
    /// When the operands compare equal, `self` is returned.
    pub fn min(self, other: Self) -> Self {
        if self.is_nan() || other.is_nan() {
            return Self::NAN;
        }
        if self.to_f32() <= other.to_f32() {
            self
        } else {
            other
        }
    }

    /// Larger of the two values; returns [`BFloat16::NAN`] if either is NaN.
    ///
    /// When the operands compare equal, `self` is returned.
    pub fn max(self, other: Self) -> Self {
        if self.is_nan() || other.is_nan() {
            return Self::NAN;
        }
        if self.to_f32() >= other.to_f32() {
            self
        } else {
            other
        }
    }

    /// Restricts the value to `[lo, hi]` as `self.max(lo).min(hi)`.
    ///
    /// A NaN in any operand yields NaN. `lo <= hi` is not checked.
    pub fn clamp(self, lo: Self, hi: Self) -> Self {
        self.max(lo).min(hi)
    }
}
135
136impl Add for BFloat16 {
139 type Output = Self;
140 fn add(self, rhs: Self) -> Self {
141 BFloat16::from_f32(self.to_f32() + rhs.to_f32())
142 }
143}
144
145impl Sub for BFloat16 {
146 type Output = Self;
147 fn sub(self, rhs: Self) -> Self {
148 BFloat16::from_f32(self.to_f32() - rhs.to_f32())
149 }
150}
151
152impl Mul for BFloat16 {
153 type Output = Self;
154 fn mul(self, rhs: Self) -> Self {
155 BFloat16::from_f32(self.to_f32() * rhs.to_f32())
156 }
157}
158
159impl Div for BFloat16 {
160 type Output = Self;
161 fn div(self, rhs: Self) -> Self {
162 BFloat16::from_f32(self.to_f32() / rhs.to_f32())
163 }
164}
165
166impl Neg for BFloat16 {
167 type Output = Self;
168 fn neg(self) -> Self {
169 Self { bits: self.bits ^ 0x8000 }
170 }
171}
172
173impl PartialOrd for BFloat16 {
174 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
175 if self.is_nan() || other.is_nan() {
176 return None;
177 }
178 self.to_f32().partial_cmp(&other.to_f32())
179 }
180}
181
182impl fmt::Debug for BFloat16 {
183 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
184 write!(f, "bf16({:.4})", self.to_f32())
185 }
186}
187
188impl fmt::Display for BFloat16 {
189 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
190 write!(f, "{:.4}", self.to_f32())
191 }
192}
193
194impl From<f32> for BFloat16 {
195 fn from(v: f32) -> Self { BFloat16::from_f32(v) }
196}
197
198impl From<BFloat16> for f32 {
199 fn from(v: BFloat16) -> f32 { v.to_f32() }
200}
201
202pub fn f32_to_bf16_slice(input: &[f32]) -> Vec<BFloat16> {
206 input.iter().map(|&v| BFloat16::from_f32(v)).collect()
207}
208
209pub fn bf16_to_f32_slice(input: &[BFloat16]) -> Vec<f32> {
211 input.iter().map(|v| v.to_f32()).collect()
212}
213
214pub fn bf16_dot(a: &[BFloat16], b: &[BFloat16]) -> f32 {
216 a.iter().zip(b.iter()).map(|(x, y)| x.to_f32() * y.to_f32()).sum()
217}
218
219pub fn bf16_gemv(a: &[BFloat16], x: &[BFloat16], rows: usize, cols: usize) -> Vec<f32> {
222 (0..rows).map(|r| {
223 let row_start = r * cols;
224 (0..cols).map(|c| {
225 a[row_start + c].to_f32() * x[c].to_f32()
226 }).sum()
227 }).collect()
228}
229
230pub fn bf16_gemm(a: &[BFloat16], b: &[BFloat16], m: usize, k: usize, n: usize) -> Vec<f32> {
233 let mut c = vec![0.0f32; m * n];
234 for i in 0..m {
235 for p in 0..k {
236 let a_val = a[i * k + p].to_f32();
237 for j in 0..n {
238 c[i * n + j] += a_val * b[p * n + j].to_f32();
239 }
240 }
241 }
242 c
243}
244
/// Unit tests for the bf16 type and the slice/BLAS-style helpers.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bf16_roundtrip() {
        let values = [0.0f32, 1.0, -1.0, 0.5, 100.0, -0.125, 3.14];
        for &v in &values {
            let bf = BFloat16::from_f32(v);
            let back = bf.to_f32();
            assert!((back - v).abs() < 0.05, "Roundtrip failed for {}: got {}", v, back);
        }
    }

    #[test]
    fn test_bf16_constants() {
        assert_eq!(BFloat16::ZERO.to_f32(), 0.0);
        assert_eq!(BFloat16::ONE.to_f32(), 1.0);
        assert_eq!(BFloat16::NEG_ONE.to_f32(), -1.0);
        assert!(BFloat16::INFINITY.is_infinite());
        assert!(BFloat16::NAN.is_nan());
        assert!(BFloat16::MAX.to_f32() > 1e38);
    }

    #[test]
    fn test_bf16_arithmetic() {
        let a = BFloat16::from_f32(2.0);
        let b = BFloat16::from_f32(3.0);
        // Compare absolute error: a one-sided `x - expected < tol` would also
        // accept results that are far too small.
        assert!(((a + b).to_f32() - 5.0).abs() < 0.1);
        assert!(((a - b).to_f32() - (-1.0)).abs() < 0.1);
        assert!(((a * b).to_f32() - 6.0).abs() < 0.1);
        assert!(((a / b).to_f32() - 0.6667).abs() < 0.02);
    }

    #[test]
    fn test_bf16_neg() {
        let a = BFloat16::from_f32(42.0);
        assert!((-a).to_f32() < 0.0);
        assert!(((-a).to_f32() + 42.0).abs() < 0.5);
    }

    #[test]
    fn test_bf16_comparison() {
        let a = BFloat16::from_f32(1.0);
        let b = BFloat16::from_f32(2.0);
        assert!(a < b);
        assert!(b > a);
        assert!(BFloat16::NAN.partial_cmp(&a).is_none());
    }

    #[test]
    fn test_bf16_special_values() {
        assert!(BFloat16::NAN.is_nan());
        assert!(!BFloat16::NAN.is_finite());
        assert!(BFloat16::INFINITY.is_infinite());
        assert!(!BFloat16::INFINITY.is_finite());
        assert!(BFloat16::ZERO.is_zero());
        // 0x8000 is negative zero.
        assert!(BFloat16::from_bits(0x8000).is_zero());
    }

    #[test]
    fn test_bf16_fma() {
        let a = BFloat16::from_f32(2.0);
        let b = BFloat16::from_f32(3.0);
        let c = BFloat16::from_f32(1.0);
        let result = BFloat16::fma(a, b, c);
        assert!((result.to_f32() - 7.0).abs() < 0.1);
    }

    #[test]
    fn test_bf16_sqrt() {
        let a = BFloat16::from_f32(4.0);
        assert!((a.sqrt().to_f32() - 2.0).abs() < 0.05);
    }

    #[test]
    fn test_bf16_clamp() {
        let lo = BFloat16::from_f32(0.0);
        let hi = BFloat16::from_f32(1.0);
        let v = BFloat16::from_f32(1.5);
        assert!((v.clamp(lo, hi).to_f32() - 1.0).abs() < 0.01);
        let v2 = BFloat16::from_f32(-0.5);
        assert!((v2.clamp(lo, hi).to_f32()).abs() < 0.01);
    }

    #[test]
    fn test_bf16_batch_convert() {
        let f32s = vec![1.0f32, 2.0, 3.0, 4.0];
        let bf16s = f32_to_bf16_slice(&f32s);
        let back = bf16_to_f32_slice(&bf16s);
        assert_eq!(back.len(), f32s.len());
        for (got, want) in back.iter().zip(&f32s) {
            assert!((got - want).abs() < 0.05);
        }
    }

    #[test]
    fn test_bf16_dot() {
        let a = f32_to_bf16_slice(&[1.0, 2.0, 3.0]);
        let b = f32_to_bf16_slice(&[4.0, 5.0, 6.0]);
        let result = bf16_dot(&a, &b);
        // 1*4 + 2*5 + 3*6 = 32
        assert!((result - 32.0).abs() < 0.5);
    }

    #[test]
    fn test_bf16_gemv() {
        // Row-major 2x2 matrix [[1, 2], [3, 4]] times [1, 1].
        let a = f32_to_bf16_slice(&[1.0, 2.0, 3.0, 4.0]);
        let x = f32_to_bf16_slice(&[1.0, 1.0]);
        let y = bf16_gemv(&a, &x, 2, 2);
        assert!((y[0] - 3.0).abs() < 0.1);
        assert!((y[1] - 7.0).abs() < 0.1);
    }

    #[test]
    fn test_bf16_gemm() {
        // [[1, 2], [3, 4]] * [[5, 6], [7, 8]] = [[19, 22], [43, 50]]
        let a = f32_to_bf16_slice(&[1.0, 2.0, 3.0, 4.0]);
        let b = f32_to_bf16_slice(&[5.0, 6.0, 7.0, 8.0]);
        let c = bf16_gemm(&a, &b, 2, 2, 2);
        assert!((c[0] - 19.0).abs() < 0.5);
        assert!((c[1] - 22.0).abs() < 0.5);
        assert!((c[2] - 43.0).abs() < 0.5);
        assert!((c[3] - 50.0).abs() < 0.5);
    }

    #[test]
    fn test_bf16_same_range_as_f32() {
        let big = BFloat16::from_f32(1e30);
        assert!(big.to_f32() > 1e29);
        assert!(big.is_finite());

        let small = BFloat16::from_f32(1e-30);
        assert!(small.to_f32() > 0.0);
        assert!(small.is_finite());
    }

    #[test]
    fn test_bf16_display() {
        let v = BFloat16::from_f32(3.14);
        let s = format!("{}", v);
        assert!(s.contains("3.1"), "Expected ~3.14, got {}", s);
    }
}