Skip to main content

rlx_ir/
dtype.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Element data types for tensors.
17
/// Scalar element type. Matches hardware-supported types.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DType {
    F32,
    F16,
    BF16,
    F64,
    I8,
    I16,
    I32,
    I64,
    U8,
    U32,
    Bool,
    /// Complex with f32 real and f32 imaginary components, stored
    /// interleaved as `[re, im, re, im, ...]`. 8 bytes per complex
    /// element. Element-wise ops (Add/Sub/Mul/Conj) follow the
    /// standard complex algebra. Reverse-mode AD on this dtype is
    /// **not yet wired** — Wirtinger conventions (∂/∂z vs ∂/∂z̄)
    /// belong to a separate pass that knows to emit conjugate-aware
    /// VJPs. The forward path is sufficient for AC analysis and
    /// FFT-based workflows that don't need to differentiate through
    /// complex math (and in fact, FFT today already encodes complex
    /// as 2N-real-block; this dtype is the natural successor).
    C64,
}

impl DType {
    /// Size in bytes of one element.
    pub const fn size_bytes(self) -> usize {
        match self {
            Self::Bool | Self::I8 | Self::U8 => 1,
            Self::F16 | Self::BF16 | Self::I16 => 2,
            Self::F32 | Self::I32 | Self::U32 => 4,
            Self::F64 | Self::I64 | Self::C64 => 8,
        }
    }

    /// True for the real floating-point dtypes (`F16`, `BF16`, `F32`,
    /// `F64`). `C64` is reported by [`DType::is_complex`] instead.
    pub const fn is_float(self) -> bool {
        matches!(self, Self::F32 | Self::F16 | Self::BF16 | Self::F64)
    }

    /// True for complex-valued dtypes. Complex elementwise ops follow
    /// standard complex algebra, distinct from the float real/imag
    /// components (e.g. complex multiply ≠ paired-real multiply).
    pub const fn is_complex(self) -> bool {
        matches!(self, Self::C64)
    }

    /// True for the signed and unsigned integer dtypes. `Bool` is not
    /// counted as an integer.
    pub const fn is_int(self) -> bool {
        matches!(
            self,
            Self::I8 | Self::I16 | Self::I32 | Self::I64 | Self::U8 | Self::U32
        )
    }

    /// True for the signed integer dtypes only.
    const fn is_signed_int(self) -> bool {
        matches!(self, Self::I8 | Self::I16 | Self::I32 | Self::I64)
    }

    /// Promotion rank — higher means "wider, more expressive".
    /// Borrowed from MAX's `dtype_promotion.py` pattern (#55 in
    /// PLAN.md): one module owns the table; ops query it instead of
    /// re-implementing ad-hoc rules.
    ///
    /// Ranks (low → high):
    ///   0 = Bool, 1 = U8/I8, 2 = I16/BF16, 3 = F16, 4 = U32/I32,
    ///   5 = I64, 6 = F32, 7 = F64, 8 = C64.
    ///
    /// The rank alone does not decide every promotion: several dtypes
    /// share a rank, so [`DType::promote`] resolves `F16`/`BF16`,
    /// int/float and signed/unsigned pairs with dedicated rules and
    /// only falls back to "higher rank wins" for the remaining pairs.
    pub const fn promotion_rank(self) -> u8 {
        match self {
            Self::Bool => 0,
            Self::U8 | Self::I8 => 1,
            Self::I16 | Self::BF16 => 2,
            Self::F16 => 3,
            Self::U32 | Self::I32 => 4,
            Self::I64 => 5,
            Self::F32 => 6,
            Self::F64 => 7,
            Self::C64 => 8,
        }
    }

    /// Float result for an `int` (integer dtype) mixed with a `float`
    /// (real float dtype): `F64` when either side is 64-bit, otherwise
    /// `F32` — half-precision floats are upcast rather than trusted to
    /// hold the integer.
    const fn float_covering_int(int: Self, float: Self) -> Self {
        match (int, float) {
            (_, Self::F64) | (Self::I64, _) => Self::F64,
            _ => Self::F32,
        }
    }

    /// Signed integer able to hold every value of both a `signed` and
    /// an `unsigned` integer dtype.
    const fn signed_covering(signed: Self, unsigned: Self) -> Self {
        if signed.size_bytes() > unsigned.size_bytes() {
            // The signed side is strictly wider, so it already covers
            // the whole unsigned range.
            signed
        } else {
            // Otherwise step up to the next signed width above the
            // unsigned operand.
            match unsigned {
                Self::U8 => Self::I16,
                _ => Self::I64,
            }
        }
    }

    /// Result dtype for a binary op between `self` and `other`.
    ///
    /// Rules, applied in order:
    /// 1. Identical dtypes are returned unchanged.
    /// 2. `f16 + bf16 → f32`: BF16 has the wider range, F16 the longer
    ///    mantissa, so neither is a lossless target for the other.
    /// 3. Mixed int+float → `F64` if either side is 64-bit, else `F32`.
    /// 4. Mixed signed+unsigned ints → the smallest signed int that
    ///    holds both (`u8 + i8 → i16`, `u32 + i32 → i64`). Letting the
    ///    rank decide here would drop the sign and make the result
    ///    depend on operand order, because these pairs share a rank.
    /// 5. Everything else → the operand with the higher
    ///    [`DType::promotion_rank`].
    ///
    /// The operation is commutative for every pair of dtypes. `C64`
    /// outranks all other dtypes, so any mix with it yields `C64`.
    pub fn promote(self, other: Self) -> Self {
        if self == other {
            return self;
        }
        if matches!(
            (self, other),
            (Self::F16, Self::BF16) | (Self::BF16, Self::F16)
        ) {
            return Self::F32;
        }
        if self.is_int() && other.is_float() {
            return Self::float_covering_int(self, other);
        }
        if self.is_float() && other.is_int() {
            return Self::float_covering_int(other, self);
        }
        if self.is_int() && other.is_int() && self.is_signed_int() != other.is_signed_int() {
            let (signed, unsigned) = if self.is_signed_int() {
                (self, other)
            } else {
                (other, self)
            };
            return Self::signed_covering(signed, unsigned);
        }
        // No rank ties remain between distinct dtypes at this point.
        if self.promotion_rank() >= other.promotion_rank() {
            self
        } else {
            other
        }
    }
}
144
/// Per-element semantics that don't fit into a flat `DType` enum
/// (plan #40). Mirrors MAX's `layout/element.mojo` `Element` type:
/// `DType` says "f8", but two FP8 variants exist (e4m3 and e5m2)
/// with different range/precision tradeoffs. Saturation policy
/// (clamp on overflow vs. wrap) is similarly orthogonal.
///
/// `DType` in this file has no dedicated FP8 variant: the FP8
/// constructors on `Element` set `dtype` to `DType::U8` and carry
/// the actual format in `subtype`.
///
/// Today most ops only care about `dtype`; downstream quantization
/// kernels read `subtype` and `saturating` to pick the right
/// dequant. Building this in early prevents the "every op grew its
/// own ad-hoc fp8 flag" mess MAX hit in v1.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Element {
    /// Scalar dtype of the element (`DType::U8` for the FP8
    /// constructors).
    pub dtype: DType,
    /// Subtype within `dtype` for FP8 variants etc. `Standard`
    /// for everything else.
    pub subtype: ElementSubtype,
    /// Whether arithmetic saturates on overflow (true for the
    /// quantized accumulator paths) or wraps (default).
    pub saturating: bool,
}
165
/// Format refinement carried next to an [`Element`]'s `dtype`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ElementSubtype {
    /// No refinement: `dtype` alone describes the element.
    Standard,
    /// FP8 e4m3 (4 exp bits, 3 mantissa) — lower range, more
    /// precision.
    /// NOTE(review): this was described as NVIDIA's "FNUZ" Hopper
    /// format, but "FNUZ" usually names a different e4m3 encoding
    /// (no negative zero, different exponent bias) than the OCP
    /// `E4M3FN` one Hopper implements — confirm which encoding the
    /// dequant kernels expect.
    Fp8E4m3,
    /// FP8 e5m2 (5 exp bits, 2 mantissa) — wider range, less
    /// precision; closer to bf16 in dynamic range.
    Fp8E5m2,
}
176
177impl Element {
178    pub const fn new(dtype: DType) -> Self {
179        Self {
180            dtype,
181            subtype: ElementSubtype::Standard,
182            saturating: false,
183        }
184    }
185    pub const fn fp8_e4m3() -> Self {
186        Self {
187            dtype: DType::U8,
188            subtype: ElementSubtype::Fp8E4m3,
189            saturating: true,
190        }
191    }
192    pub const fn fp8_e5m2() -> Self {
193        Self {
194            dtype: DType::U8,
195            subtype: ElementSubtype::Fp8E5m2,
196            saturating: true,
197        }
198    }
199    pub const fn saturating(self) -> Self {
200        Self {
201            saturating: true,
202            ..self
203        }
204    }
205}
206
207impl std::fmt::Display for DType {
208    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
209        match self {
210            Self::F32 => write!(f, "f32"),
211            Self::F16 => write!(f, "f16"),
212            Self::BF16 => write!(f, "bf16"),
213            Self::F64 => write!(f, "f64"),
214            Self::I8 => write!(f, "i8"),
215            Self::I16 => write!(f, "i16"),
216            Self::I32 => write!(f, "i32"),
217            Self::I64 => write!(f, "i64"),
218            Self::U8 => write!(f, "u8"),
219            Self::U32 => write!(f, "u32"),
220            Self::Bool => write!(f, "bool"),
221            Self::C64 => write!(f, "c64"),
222        }
223    }
224}
225
#[cfg(test)]
mod tests {
    use super::*;

    /// `Element` constructors fill in dtype, subtype and saturation
    /// as expected.
    #[test]
    fn element_constructors() {
        let plain = Element::new(DType::F32);
        assert_eq!(
            (plain.dtype, plain.subtype, plain.saturating),
            (DType::F32, ElementSubtype::Standard, false)
        );

        let e4m3 = Element::fp8_e4m3();
        assert_eq!(
            (e4m3.subtype, e4m3.saturating, e4m3.dtype),
            (ElementSubtype::Fp8E4m3, true, DType::U8)
        );

        let sat = Element::new(DType::I32).saturating();
        assert_eq!((sat.saturating, sat.dtype), (true, DType::I32));
    }

    /// Promoting a dtype with itself is the identity.
    #[test]
    fn promote_same() {
        for dt in [DType::F32, DType::I8] {
            assert_eq!(dt.promote(dt), dt);
        }
    }

    /// Signed integers widen to the larger operand.
    #[test]
    fn promote_int_widening() {
        let cases = [
            (DType::I8, DType::I16, DType::I16),
            (DType::I32, DType::I64, DType::I64),
        ];
        for (lhs, rhs, want) in cases {
            assert_eq!(lhs.promote(rhs), want);
        }
    }

    /// Integer mixed with float lands on a float wide enough for both.
    #[test]
    fn promote_int_to_float() {
        let cases = [
            (DType::I32, DType::F32, DType::F32),
            (DType::I64, DType::F32, DType::F64),
            (DType::I8, DType::F16, DType::F32),
        ];
        for (lhs, rhs, want) in cases {
            assert_eq!(lhs.promote(rhs), want);
        }
    }

    /// The two half-precision formats meet at `F32`, in either order.
    #[test]
    fn promote_f16_bf16_goes_to_f32() {
        for (lhs, rhs) in [(DType::F16, DType::BF16), (DType::BF16, DType::F16)] {
            assert_eq!(lhs.promote(rhs), DType::F32);
        }
    }

    /// Operand order does not matter for the listed pairs.
    #[test]
    fn promote_is_commutative_for_well_defined_pairs() {
        [
            (DType::F32, DType::F16),
            (DType::I32, DType::F64),
            (DType::Bool, DType::I8),
        ]
        .into_iter()
        .for_each(|(a, b)| {
            assert_eq!(
                a.promote(b),
                b.promote(a),
                "promote({a},{b}) should equal promote({b},{a})"
            );
        });
    }
}