rlx_runtime/precision.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Precision selection for graph execution.
17//!
18//! Each backend can compile a graph at f32 (default — accurate) or f16
19//! (half precision — 2× peak FLOPs and ½ memory bandwidth on supported
20//! hardware). The IR remains dtype-agnostic; the backend decides how to
21//! materialize buffers and pick kernels.
22//!
23//! Mixed precision: f16 inference typically keeps reductions (LayerNorm
24//! mean/var, attention softmax) in f32 to avoid catastrophic accuracy
25//! loss while keeping matmul + element-wise in f16.
26
/// Numeric precision for graph compilation.
///
/// The default is [`Precision::F32`]. The type is a small `Copy` tag and
/// derives `Hash` alongside `Eq`, so it can be used directly as a key in
/// hashed collections such as `HashMap` and `HashSet`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum Precision {
    /// Full single precision. Always supported; accurate; baseline.
    #[default]
    F32,
    /// Half precision (IEEE 754 binary16). Native on Apple Silicon GPU
    /// and many CPUs (NEON `vfmaq_f16`). 2× FLOPs / 0.5× memory vs F32.
    /// Reductions are still computed in F32 for numerical stability.
    F16,
    /// Brain-float: 8-bit exponent, 7-bit mantissa. Same range as F32,
    /// less precision. Used in many LLMs. Accelerator-dependent.
    BF16,
}
41
42impl Precision {
43 /// Bytes per scalar at this precision.
44 pub fn size_bytes(self) -> usize {
45 match self {
46 Precision::F32 => 4,
47 Precision::F16 | Precision::BF16 => 2,
48 }
49 }
50
51 /// Backward-compatible alias used in older code.
52 pub fn bytes(self) -> usize {
53 self.size_bytes()
54 }
55}
56
57impl std::fmt::Display for Precision {
58 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59 match self {
60 Precision::F32 => write!(f, "f32"),
61 Precision::F16 => write!(f, "f16"),
62 Precision::BF16 => write!(f, "bf16"),
63 }
64 }
65}