Skip to main content

rlx_runtime/
precision.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
//! Precision selection for graph execution.
//!
//! Each backend can compile a graph at f32 (the default, most accurate), f16
//! (IEEE half precision: up to 2× peak FLOPs and half the memory traffic on
//! supported hardware), or bf16 (f32's exponent range with a shorter
//! mantissa). The IR remains dtype-agnostic; the backend decides how to
//! materialize buffers and pick kernels.
//!
//! Mixed precision: half-precision inference typically keeps reductions
//! (LayerNorm mean/var, attention softmax) in f32 to avoid catastrophic
//! accuracy loss, while matmuls and element-wise ops run at the lower precision.
26
/// Numeric precision a graph is compiled at.
///
/// [`Precision::F32`] is the default and is always available; the 16-bit
/// variants trade accuracy for throughput and memory traffic on hardware
/// that supports them.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum Precision {
    /// IEEE 754 binary32. Always supported and accurate; the baseline.
    #[default]
    F32,
    /// IEEE 754 binary16 (half precision). Native on Apple Silicon GPUs and
    /// on many CPUs (e.g. NEON `vfmaq_f16`); roughly 2× the FLOPs and half
    /// the memory of [`Precision::F32`]. Reductions are still computed in
    /// F32 for numerical stability.
    F16,
    /// bfloat16: 8-bit exponent and 7-bit mantissa, giving the same range as
    /// F32 with less precision. Used by many LLMs; availability depends on
    /// the accelerator.
    BF16,
}
41
42impl Precision {
43    /// Bytes per scalar at this precision.
44    pub fn size_bytes(self) -> usize {
45        match self {
46            Precision::F32 => 4,
47            Precision::F16 | Precision::BF16 => 2,
48        }
49    }
50
51    /// Backward-compatible alias used in older code.
52    pub fn bytes(self) -> usize {
53        self.size_bytes()
54    }
55}
56
57impl std::fmt::Display for Precision {
58    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59        match self {
60            Precision::F32 => write!(f, "f32"),
61            Precision::F16 => write!(f, "f16"),
62            Precision::BF16 => write!(f, "bf16"),
63        }
64    }
65}