//! FP8 matrix multiplication operations trait.
//!
//! FP8 matmul differs from standard matmul in two key ways:
//! 1. Per-tensor scale factors compensate for the limited dynamic range of FP8
//! 2. Accumulation is always in FP32 for numerical accuracy
//!
//! The output dtype can differ from input dtype (typically F32, F16, or BF16).
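//!
//! As a concrete illustration of point 2 (numbers chosen for exposition, not
//! taken from this crate): summing 1024 products of magnitude ~1.0 gives a
//! result near 1024.0, which already exceeds E4M3's maximum finite value of
//! 448.0, so an FP8 accumulator would overflow where an FP32 accumulator does not.
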
use crate::DType;
use crate::Result;
use crate::Runtime;
use crate::Tensor;

/// FP8 matrix multiplication operations with per-tensor scaling.
///
/// FP8 GEMM computes: `output = (scale_a * A) @ (scale_b * B)` where A and B are
/// FP8 tensors, arithmetic is performed in FP32, and the result is cast to `out_dtype`.
///
/// # Scale Factors
///
/// FP8 has a very limited representable range (roughly ±448 for E4M3, ±57344 for E5M2).
/// Per-tensor scale factors map the original tensor range into the FP8 representable range:
///
/// ```text
/// quantize:   A_fp8 = A / scale_a
/// dequantize: A ≈ A_fp8 * scale_a
/// matmul:     C = (scale_a * A_fp8) @ (scale_b * B_fp8) = scale_a * scale_b * (A_fp8 @ B_fp8)
/// ```
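///
/// As a worked example (illustrative numbers, not taken from this crate): a
/// tensor whose maximum absolute value is 1792.0 can be quantized to E4M3
/// (max finite value 448.0) with `scale = 1792.0 / 448.0 = 4.0`; the element
/// 896.0 is stored as `224.0` in FP8 and recovered as `224.0 * 4.0 = 896.0`
/// on dequantization.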
///
/// # Use Cases
///
/// - `fp8_matmul`: E4M3 x E4M3 — forward pass (weights and activations)
/// - `fp8_matmul_e5m2`: E5M2 x E4M3 — backward pass (gradients x weights)
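///
/// # Example
///
/// A minimal usage sketch. The excerpt above names the methods but not their
/// signatures, so the parameter list below (FP8 operands, per-tensor `f32`
/// scales, and an output dtype) is an assumption:
///
/// ```ignore
/// // a_fp8 and b_fp8 were quantized to E4M3 with per-tensor scales.
/// // C = scale_a * scale_b * (a_fp8 @ b_fp8), accumulated in FP32,
/// // then cast to BF16.
/// let c: Tensor = runtime.fp8_matmul(&a_fp8, &b_fp8, scale_a, scale_b, DType::BF16)?;
/// ```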