1#[cfg(feature = "no-std")]
7use core::fmt;
8#[cfg(not(feature = "no-std"))]
9use std::fmt;
10
/// IEEE 754 binary16 ("half precision") value stored as its raw bit pattern.
///
/// Layout: 1 sign bit, 5 exponent bits (bias 15), 10 mantissa bits.
///
/// The derived `PartialEq` compares bit patterns, so `+0.0 != -0.0` and two
/// NaNs with identical bits compare equal. Convert with [`F16::to_f32`] for
/// IEEE comparison semantics.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct F16(pub u16);

/// bfloat16 value stored as its raw bit pattern.
///
/// Layout: 1 sign bit, 8 exponent bits (bias 127), 7 mantissa bits, i.e. the
/// upper half of an IEEE 754 binary32 value. The derived `PartialEq` compares
/// bit patterns, exactly as for [`F16`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct BF16(pub u16);

impl F16 {
    /// Wraps a raw binary16 bit pattern without interpreting it.
    pub const fn from_bits(bits: u16) -> Self {
        F16(bits)
    }

    /// Returns the raw binary16 bit pattern.
    pub const fn to_bits(self) -> u16 {
        self.0
    }

    /// Converts an `f32` to the nearest `F16`, rounding ties to even.
    ///
    /// * Finite magnitudes of 65520 or more (after rounding) become ±infinity.
    /// * Magnitudes below the normal range become binary16 subnormals; anything
    ///   at or below half the smallest subnormal (2^-25) becomes a signed zero.
    /// * Infinities keep their sign; every NaN becomes the quiet NaN with an
    ///   all-ones mantissa and the input's sign.
    pub fn from_f32(value: f32) -> Self {
        let bits = value.to_bits();
        let sign = ((bits >> 16) & 0x8000) as u16;
        let exp = ((bits >> 23) & 0xFF) as i32;
        let mant = bits & 0x007F_FFFF;

        if exp == 0xFF {
            let payload: u16 = if mant == 0 { 0 } else { 0x03FF };
            return F16(sign | 0x7C00 | payload);
        }

        // Re-bias the exponent from binary32 (127) to binary16 (15).
        let half_exp = exp - 127 + 15;

        if half_exp >= 0x1F {
            // Too large even before rounding.
            return F16(sign | 0x7C00);
        }

        if half_exp <= 0 {
            // Subnormal (or zero) result. The value is
            // `(mant | implicit one) * 2^(half_exp - 38)` and a subnormal
            // mantissa counts units of 2^-24, so shift right by `14 - half_exp`.
            let shift = (14 - half_exp) as u32;
            if shift > 24 {
                // Below half the smallest subnormal; this also covers ±0.0 and
                // binary32 subnormals.
                return F16(sign);
            }
            // A rounding carry lands in bit 10, which encodes the smallest
            // normal exponent, so the result is still correctly encoded.
            let half_mant = Self::round_shift_right_even(mant | 0x0080_0000, shift);
            return F16(sign | half_mant as u16);
        }

        // Normal result: drop 13 mantissa bits. The rounded mantissa is ADDED
        // to the packed exponent so a carry (all-ones mantissa rounding up)
        // increments the exponent, overflowing to infinity (0x7C00) at the top.
        let packed = ((half_exp as u32) << 10) + Self::round_shift_right_even(mant, 13);
        F16(sign | packed as u16)
    }

    /// Widens to `f32`.
    ///
    /// Exact for every finite value (including subnormals) and for the
    /// infinities. Every NaN maps to an `f32` NaN with all mantissa bits set,
    /// so the NaN payload is not preserved.
    pub fn to_f32(self) -> f32 {
        let bits = u32::from(self.0);
        let sign = (bits & 0x8000) << 16;
        let exp = (bits >> 10) & 0x1F;
        let mant = bits & 0x03FF;

        let magnitude = match (exp, mant) {
            (0, 0) => 0,
            (0, _) => {
                // Subnormal: value = mant * 2^-24. Shift the leading one up to
                // bit 10 (the implicit-bit position); `mant` has its top bit
                // somewhere in bits 0..=9, so `shift` is in 1..=10.
                let shift = mant.leading_zeros() - 21;
                ((127 - 14 - shift) << 23) | (((mant << shift) & 0x03FF) << 13)
            }
            (0x1F, 0) => 0x7F80_0000,
            (0x1F, _) => 0x7FFF_FFFF,
            _ => ((exp + 127 - 15) << 23) | (mant << 13),
        };
        f32::from_bits(sign | magnitude)
    }

    /// Returns `true` unless the value is an infinity or NaN.
    pub const fn is_finite(self) -> bool {
        (self.0 & 0x7C00) != 0x7C00
    }

    /// Returns `true` for +infinity and -infinity.
    pub const fn is_infinite(self) -> bool {
        (self.0 & 0x7FFF) == 0x7C00
    }

    /// Returns `true` for any NaN bit pattern.
    pub const fn is_nan(self) -> bool {
        (self.0 & 0x7C00) == 0x7C00 && (self.0 & 0x03FF) != 0
    }

    /// Shifts `value` right by `shift` bits (`1..=31`), rounding the discarded
    /// bits to nearest and breaking exact ties toward an even result.
    fn round_shift_right_even(value: u32, shift: u32) -> u32 {
        let kept = value >> shift;
        let halfway = 1u32 << (shift - 1);
        let discarded = value & ((halfway << 1) - 1);
        let round_up = discarded > halfway || (discarded == halfway && (kept & 1) == 1);
        kept + u32::from(round_up)
    }
}

impl BF16 {
    /// Wraps a raw bfloat16 bit pattern without interpreting it.
    pub const fn from_bits(bits: u16) -> Self {
        BF16(bits)
    }

    /// Returns the raw bfloat16 bit pattern.
    pub const fn to_bits(self) -> u16 {
        self.0
    }

    /// Converts an `f32` to the nearest `BF16`, rounding ties to even.
    ///
    /// Values that round past the largest finite bfloat16 become ±infinity.
    /// NaNs are truncated with the quiet bit forced on, so they stay NaN.
    pub fn from_f32(value: f32) -> Self {
        let bits = value.to_bits();

        // NaN must be handled before the rounding add. Otherwise a payload held
        // only in the low half would round to infinity or carry into the sign
        // bit, and bit patterns near 0xFFFF_FFFF would overflow the `u32` add.
        if (bits & 0x7FFF_FFFF) > 0x7F80_0000 {
            return BF16(((bits >> 16) as u16) | 0x0040);
        }

        // Round to nearest, ties to even: add 0x7FFF plus the lowest kept bit.
        // The largest non-NaN pattern is 0xFF80_0000, so this cannot overflow.
        let lsb = (bits >> 16) & 1;
        BF16(((bits + 0x7FFF + lsb) >> 16) as u16)
    }

    /// Widens to `f32` exactly by placing the bits in the upper half.
    pub fn to_f32(self) -> f32 {
        f32::from_bits(u32::from(self.0) << 16)
    }

    /// Returns `true` unless the value is an infinity or NaN.
    pub const fn is_finite(self) -> bool {
        (self.0 & 0x7F80) != 0x7F80
    }

    /// Returns `true` for +infinity and -infinity.
    pub const fn is_infinite(self) -> bool {
        (self.0 & 0x7FFF) == 0x7F80
    }

    /// Returns `true` for any NaN bit pattern.
    pub const fn is_nan(self) -> bool {
        (self.0 & 0x7F80) == 0x7F80 && (self.0 & 0x007F) != 0
    }
}
160
161impl fmt::Display for F16 {
162 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
163 write!(f, "{}", self.to_f32())
164 }
165}
166
167impl fmt::Display for BF16 {
168 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
169 write!(f, "{}", self.to_f32())
170 }
171}
172
173pub mod simd {
175 use super::*;
176
177 pub fn f32_to_f16_slice(input: &[f32], output: &mut [F16]) {
179 assert_eq!(input.len(), output.len());
180
181 const CHUNK_SIZE: usize = 8;
183 let chunks = input.len() / CHUNK_SIZE;
184
185 for i in 0..chunks {
186 let start = i * CHUNK_SIZE;
187 let end = start + CHUNK_SIZE;
188
189 for j in start..end {
191 output[j] = F16::from_f32(input[j]);
192 }
193 }
194
195 for i in (chunks * CHUNK_SIZE)..input.len() {
197 output[i] = F16::from_f32(input[i]);
198 }
199 }
200
201 pub fn f16_to_f32_slice(input: &[F16], output: &mut [f32]) {
203 assert_eq!(input.len(), output.len());
204
205 const CHUNK_SIZE: usize = 8;
206 let chunks = input.len() / CHUNK_SIZE;
207
208 for i in 0..chunks {
209 let start = i * CHUNK_SIZE;
210 let end = start + CHUNK_SIZE;
211
212 for j in start..end {
213 output[j] = input[j].to_f32();
214 }
215 }
216
217 for i in (chunks * CHUNK_SIZE)..input.len() {
218 output[i] = input[i].to_f32();
219 }
220 }
221
222 pub fn f32_to_bf16_slice(input: &[f32], output: &mut [BF16]) {
224 assert_eq!(input.len(), output.len());
225
226 const CHUNK_SIZE: usize = 8;
227 let chunks = input.len() / CHUNK_SIZE;
228
229 for i in 0..chunks {
230 let start = i * CHUNK_SIZE;
231 let end = start + CHUNK_SIZE;
232
233 for j in start..end {
234 output[j] = BF16::from_f32(input[j]);
235 }
236 }
237
238 for i in (chunks * CHUNK_SIZE)..input.len() {
239 output[i] = BF16::from_f32(input[i]);
240 }
241 }
242
243 pub fn bf16_to_f32_slice(input: &[BF16], output: &mut [f32]) {
245 assert_eq!(input.len(), output.len());
246
247 const CHUNK_SIZE: usize = 8;
248 let chunks = input.len() / CHUNK_SIZE;
249
250 for i in 0..chunks {
251 let start = i * CHUNK_SIZE;
252 let end = start + CHUNK_SIZE;
253
254 for j in start..end {
255 output[j] = input[j].to_f32();
256 }
257 }
258
259 for i in (chunks * CHUNK_SIZE)..input.len() {
260 output[i] = input[i].to_f32();
261 }
262 }
263
264 pub fn add_f16(a: &[F16], b: &[F16], result: &mut [F16]) {
266 assert_eq!(a.len(), b.len());
267 assert_eq!(a.len(), result.len());
268
269 for i in 0..a.len() {
270 let sum = a[i].to_f32() + b[i].to_f32();
271 result[i] = F16::from_f32(sum);
272 }
273 }
274
275 pub fn mul_f16(a: &[F16], b: &[F16], result: &mut [F16]) {
277 assert_eq!(a.len(), b.len());
278 assert_eq!(a.len(), result.len());
279
280 for i in 0..a.len() {
281 let product = a[i].to_f32() * b[i].to_f32();
282 result[i] = F16::from_f32(product);
283 }
284 }
285
286 pub fn add_bf16(a: &[BF16], b: &[BF16], result: &mut [BF16]) {
288 assert_eq!(a.len(), b.len());
289 assert_eq!(a.len(), result.len());
290
291 for i in 0..a.len() {
292 let sum = a[i].to_f32() + b[i].to_f32();
293 result[i] = BF16::from_f32(sum);
294 }
295 }
296
297 pub fn mul_bf16(a: &[BF16], b: &[BF16], result: &mut [BF16]) {
299 assert_eq!(a.len(), b.len());
300 assert_eq!(a.len(), result.len());
301
302 for i in 0..a.len() {
303 let product = a[i].to_f32() * b[i].to_f32();
304 result[i] = BF16::from_f32(product);
305 }
306 }
307
308 pub fn dot_f16(a: &[F16], b: &[F16]) -> f32 {
310 assert_eq!(a.len(), b.len());
311
312 let mut sum = 0.0f32;
313 for i in 0..a.len() {
314 sum += a[i].to_f32() * b[i].to_f32();
315 }
316 sum
317 }
318
319 pub fn dot_bf16(a: &[BF16], b: &[BF16]) -> f32 {
321 assert_eq!(a.len(), b.len());
322
323 let mut sum = 0.0f32;
324 for i in 0..a.len() {
325 sum += a[i].to_f32() * b[i].to_f32();
326 }
327 sum
328 }
329
330 pub fn matmul_f16(a: &[F16], b: &[F16], c: &mut [F16], m: usize, n: usize, k: usize) {
332 assert_eq!(a.len(), m * k);
333 assert_eq!(b.len(), k * n);
334 assert_eq!(c.len(), m * n);
335
336 for i in 0..m {
337 for j in 0..n {
338 let mut sum = 0.0f32;
339 for l in 0..k {
340 sum += a[i * k + l].to_f32() * b[l * n + j].to_f32();
341 }
342 c[i * n + j] = F16::from_f32(sum);
343 }
344 }
345 }
346
347 pub fn matmul_bf16(a: &[BF16], b: &[BF16], c: &mut [BF16], m: usize, n: usize, k: usize) {
349 assert_eq!(a.len(), m * k);
350 assert_eq!(b.len(), k * n);
351 assert_eq!(c.len(), m * n);
352
353 for i in 0..m {
354 for j in 0..n {
355 let mut sum = 0.0f32;
356 for l in 0..k {
357 sum += a[i * k + l].to_f32() * b[l * n + j].to_f32();
358 }
359 c[i * n + j] = BF16::from_f32(sum);
360 }
361 }
362 }
363}
364
365pub mod constants {
367 use super::*;
368
369 pub const F16_ZERO: F16 = F16(0);
370 pub const F16_ONE: F16 = F16(0x3C00);
371 pub const F16_NEG_ONE: F16 = F16(0xBC00);
372 pub const F16_INFINITY: F16 = F16(0x7C00);
373 pub const F16_NEG_INFINITY: F16 = F16(0xFC00);
374 pub const F16_NAN: F16 = F16(0x7E00);
375 pub const F16_MAX: F16 = F16(0x7BFF);
376 pub const F16_MIN: F16 = F16(0x0400);
377 pub const F16_EPSILON: F16 = F16(0x1400);
378
379 pub const BF16_ZERO: BF16 = BF16(0);
380 pub const BF16_ONE: BF16 = BF16(0x3F80);
381 pub const BF16_NEG_ONE: BF16 = BF16(0xBF80);
382 pub const BF16_INFINITY: BF16 = BF16(0x7F80);
383 pub const BF16_NEG_INFINITY: BF16 = BF16(0xFF80);
384 pub const BF16_NAN: BF16 = BF16(0x7FC0);
385 pub const BF16_MAX: BF16 = BF16(0x7F7F);
386 pub const BF16_MIN: BF16 = BF16(0x0080);
387 pub const BF16_EPSILON: BF16 = BF16(0x3C00);
388}
389
// The whole module is std-only, so no `alloc` imports are needed here.
#[cfg(all(test, not(feature = "no-std")))]
mod tests {
    use super::constants::*;
    use super::*;

    /// Asserts `|actual - expected| < tolerance`, reporting all three on failure.
    fn assert_close(actual: f32, expected: f32, tolerance: f32) {
        assert!(
            (actual - expected).abs() < tolerance,
            "{} differs from {} by at least {}",
            actual,
            expected,
            tolerance
        );
    }

    /// Converts a list of `f32` literals to `F16`.
    fn f16s(values: &[f32]) -> Vec<F16> {
        values.iter().copied().map(F16::from_f32).collect()
    }

    /// Converts a list of `f32` literals to `BF16`.
    fn bf16s(values: &[f32]) -> Vec<BF16> {
        values.iter().copied().map(BF16::from_f32).collect()
    }

    #[test]
    fn test_f16_conversion() {
        let val = std::f32::consts::PI;
        assert_close(F16::from_f32(val).to_f32(), val, 0.01);
    }

    #[test]
    fn test_bf16_conversion() {
        let val = std::f32::consts::PI;
        assert_close(BF16::from_f32(val).to_f32(), val, 0.01);
    }

    #[test]
    fn test_f16_constants() {
        assert_eq!(F16_ZERO.to_f32(), 0.0);
        assert_eq!(F16_ONE.to_f32(), 1.0);
        assert_eq!(F16_NEG_ONE.to_f32(), -1.0);
        assert!(F16_INFINITY.is_infinite());
        assert!(F16_NAN.is_nan());
    }

    #[test]
    fn test_bf16_constants() {
        assert_eq!(BF16_ZERO.to_f32(), 0.0);
        assert_eq!(BF16_ONE.to_f32(), 1.0);
        assert_eq!(BF16_NEG_ONE.to_f32(), -1.0);
        assert!(BF16_INFINITY.is_infinite());
        assert!(BF16_NAN.is_nan());
    }

    #[test]
    fn test_f16_special_values() {
        assert!(F16::from_f32(f32::INFINITY).is_infinite());
        assert!(F16::from_f32(f32::NEG_INFINITY).is_infinite());
        assert!(F16::from_f32(f32::NAN).is_nan());
    }

    #[test]
    fn test_bf16_special_values() {
        assert!(BF16::from_f32(f32::INFINITY).is_infinite());
        assert!(BF16::from_f32(f32::NEG_INFINITY).is_infinite());
        assert!(BF16::from_f32(f32::NAN).is_nan());
    }

    #[test]
    fn test_simd_f32_to_f16_conversion() {
        let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let mut output = [F16::from_bits(0); 8];

        simd::f32_to_f16_slice(&input, &mut output);

        for (&expected, actual) in input.iter().zip(&output) {
            assert_close(actual.to_f32(), expected, 0.01);
        }
    }

    #[test]
    fn test_simd_f32_to_bf16_conversion() {
        let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let mut output = [BF16::from_bits(0); 8];

        simd::f32_to_bf16_slice(&input, &mut output);

        for (&expected, actual) in input.iter().zip(&output) {
            assert_close(actual.to_f32(), expected, 0.01);
        }
    }

    #[test]
    fn test_f16_arithmetic() {
        let a = f16s(&[1.0, 2.0, 3.0]);
        let b = f16s(&[4.0, 5.0, 6.0]);
        let mut result = [F16::from_bits(0); 3];

        simd::add_f16(&a, &b, &mut result);

        for (actual, &expected) in result.iter().zip(&[5.0f32, 7.0, 9.0]) {
            assert_close(actual.to_f32(), expected, 0.01);
        }
    }

    #[test]
    fn test_bf16_arithmetic() {
        let a = bf16s(&[1.0, 2.0, 3.0]);
        let b = bf16s(&[4.0, 5.0, 6.0]);
        let mut result = [BF16::from_bits(0); 3];

        simd::add_bf16(&a, &b, &mut result);

        for (actual, &expected) in result.iter().zip(&[5.0f32, 7.0, 9.0]) {
            assert_close(actual.to_f32(), expected, 0.01);
        }
    }

    #[test]
    fn test_f16_dot_product() {
        let a = f16s(&[1.0, 2.0, 3.0]);
        let b = f16s(&[4.0, 5.0, 6.0]);

        let result = simd::dot_f16(&a, &b);
        let expected = 1.0 * 4.0 + 2.0 * 5.0 + 3.0 * 6.0;
        assert_close(result, expected, 0.1);
    }

    #[test]
    fn test_bf16_dot_product() {
        let a = bf16s(&[1.0, 2.0, 3.0]);
        let b = bf16s(&[4.0, 5.0, 6.0]);

        let result = simd::dot_bf16(&a, &b);
        let expected = 1.0 * 4.0 + 2.0 * 5.0 + 3.0 * 6.0;
        assert_close(result, expected, 0.1);
    }

    #[test]
    fn test_f16_matrix_multiplication() {
        // [[1, 2], [3, 4]] × [[5, 6], [7, 8]] = [[19, 22], [43, 50]]
        let a = f16s(&[1.0, 2.0, 3.0, 4.0]);
        let b = f16s(&[5.0, 6.0, 7.0, 8.0]);
        let mut c = [F16::from_bits(0); 4];

        simd::matmul_f16(&a, &b, &mut c, 2, 2, 2);

        for (actual, &expected) in c.iter().zip(&[19.0f32, 22.0, 43.0, 50.0]) {
            assert_close(actual.to_f32(), expected, 0.1);
        }
    }

    #[test]
    fn test_bf16_matrix_multiplication() {
        // [[1, 2], [3, 4]] × [[5, 6], [7, 8]] = [[19, 22], [43, 50]]
        let a = bf16s(&[1.0, 2.0, 3.0, 4.0]);
        let b = bf16s(&[5.0, 6.0, 7.0, 8.0]);
        let mut c = [BF16::from_bits(0); 4];

        simd::matmul_bf16(&a, &b, &mut c, 2, 2, 2);

        for (actual, &expected) in c.iter().zip(&[19.0f32, 22.0, 43.0, 50.0]) {
            assert_close(actual.to_f32(), expected, 0.1);
        }
    }

    #[test]
    fn test_large_vector_conversion() {
        let size = 1024;
        let input: Vec<f32> = (0..size).map(|i| i as f32 * 0.1).collect();
        let mut f16_output = vec![F16::from_bits(0); size];
        let mut bf16_output = vec![BF16::from_bits(0); size];

        simd::f32_to_f16_slice(&input, &mut f16_output);
        simd::f32_to_bf16_slice(&input, &mut bf16_output);

        for ((&x, f16), bf16) in input.iter().zip(&f16_output).zip(&bf16_output) {
            let f16_error = (x - f16.to_f32()).abs();
            let bf16_error = (x - bf16.to_f32()).abs();

            // 1% relative tolerance above magnitude 1, 0.01 absolute below.
            let tolerance = if x.abs() > 1.0 { x.abs() * 0.01 } else { 0.01 };

            assert!(
                f16_error < tolerance,
                "F16 error {:.6} > tolerance {:.6} for input {:.6}",
                f16_error,
                tolerance,
                x
            );
            assert!(
                bf16_error < tolerance,
                "BF16 error {:.6} > tolerance {:.6} for input {:.6}",
                bf16_error,
                tolerance,
                x
            );
        }
    }

    #[test]
    fn test_precision_comparison() {
        let test_values = [
            0.0f32,
            1.0,
            -1.0,
            0.5,
            -0.5,
            std::f32::consts::PI,
            std::f32::consts::E,
            std::f32::consts::SQRT_2,
            1.73205,
            0.1,
            0.01,
            0.001,
            0.0001,
        ];

        for &val in &test_values {
            let f16_error = (val - F16::from_f32(val).to_f32()).abs();
            let bf16_error = (val - BF16::from_f32(val).to_f32()).abs();

            assert!(f16_error < 0.01 || val.abs() < 0.01);
            assert!(bf16_error < 0.01 || val.abs() < 0.01);
        }
    }
}
669}