rlx_ir/dtype.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Element data types for tensors.
17
/// Scalar element type of a tensor, matching hardware-supported types.
///
/// Covers real floats (`F16`, `BF16`, `F32`, `F64`), signed integers
/// (`I8`, `I16`, `I32`, `I64`), unsigned integers (`U8`, `U32`), `Bool`,
/// and single-precision complex (`C64`). Per-element byte widths come
/// from `DType::size_bytes`.
///
/// NOTE(review): variant declaration order is observable through the
/// derived `Hash` and through index-based serde formats when the
/// `serialize` feature is enabled — prefer appending new variants over
/// inserting them in the middle.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DType {
    /// 32-bit float (4 bytes).
    F32,
    /// 16-bit half-precision float: narrower range than `BF16`, more
    /// mantissa bits.
    F16,
    /// 16-bit bfloat16: wider range than `F16`, fewer mantissa bits.
    BF16,
    /// 64-bit float (8 bytes).
    F64,
    /// 8-bit signed integer.
    I8,
    /// 16-bit signed integer.
    I16,
    /// 32-bit signed integer.
    I32,
    /// 64-bit signed integer.
    I64,
    /// 8-bit unsigned integer. Also used as the storage container for
    /// FP8 elements (see `Element::fp8_e4m3` / `Element::fp8_e5m2`).
    U8,
    /// 32-bit unsigned integer.
    U32,
    /// Boolean, stored as one byte per element.
    Bool,
    /// Complex with f32 real and f32 imaginary components, stored
    /// interleaved as `[re, im, re, im, ...]`. 8 bytes per complex
    /// element. Element-wise ops (Add/Sub/Mul/Conj) follow the
    /// standard complex algebra. Reverse-mode AD on this dtype is
    /// **not yet wired** — Wirtinger conventions (∂/∂z vs ∂/∂z̄)
    /// belong to a separate pass that knows to emit conjugate-aware
    /// VJPs. The forward path is sufficient for AC analysis and
    /// FFT-based workflows that don't need to differentiate through
    /// complex math (and in fact, FFT today already encodes complex
    /// as 2N-real-block; this dtype is the natural successor).
    C64,
}
45
46impl DType {
47 /// Size in bytes of one element.
48 pub const fn size_bytes(self) -> usize {
49 match self {
50 Self::Bool | Self::I8 | Self::U8 => 1,
51 Self::F16 | Self::BF16 | Self::I16 => 2,
52 Self::F32 | Self::I32 | Self::U32 => 4,
53 Self::F64 | Self::I64 | Self::C64 => 8,
54 }
55 }
56
57 pub const fn is_float(self) -> bool {
58 matches!(self, Self::F32 | Self::F16 | Self::BF16 | Self::F64)
59 }
60
61 /// True for complex-valued dtypes. Complex elementwise ops follow
62 /// standard complex algebra, distinct from the float real/imag
63 /// components (e.g. complex multiply ≠ paired-real multiply).
64 pub const fn is_complex(self) -> bool {
65 matches!(self, Self::C64)
66 }
67
68 pub const fn is_int(self) -> bool {
69 matches!(
70 self,
71 Self::I8 | Self::I16 | Self::I32 | Self::I64 | Self::U8 | Self::U32
72 )
73 }
74
75 /// Promotion rank — higher means "wider, more expressive". The
76 /// promoted dtype of a binary op is `max(rank(lhs), rank(rhs))`.
77 /// Borrowed from MAX's `dtype_promotion.py` pattern (#55 in
78 /// PLAN.md): one module owns the table; ops query it instead of
79 /// re-implementing ad-hoc rules.
80 ///
81 /// Ranks (low → high):
82 /// 0 = Bool, 1 = U8/I8, 2 = I16/BF16, 3 = F16, 4 = U32/I32,
83 /// 5 = I64, 6 = F32, 7 = F64.
84 /// Floats outrank ints of the same width (matches PyTorch /
85 /// NumPy). BF16 promotes to F32 against F16 since BF16 has
86 /// wider range but F16 has more mantissa.
87 pub const fn promotion_rank(self) -> u8 {
88 match self {
89 Self::Bool => 0,
90 Self::U8 | Self::I8 => 1,
91 Self::I16 | Self::BF16 => 2,
92 Self::F16 => 3,
93 Self::U32 | Self::I32 => 4,
94 Self::I64 => 5,
95 Self::F32 => 6,
96 Self::F64 => 7,
97 Self::C64 => 8,
98 }
99 }
100
101 /// Result dtype for a binary op between `self` and `other`.
102 /// Mixed int+float → float at least as wide as either input.
103 /// `f16 + bf16 → f32` (no clean lossless target).
104 pub fn promote(self, other: Self) -> Self {
105 if self == other {
106 return self;
107 }
108 // Special case: f16 + bf16 → f32 (their domains are too
109 // different to lose precision in either direction).
110 if matches!(
111 (self, other),
112 (Self::F16, Self::BF16) | (Self::BF16, Self::F16)
113 ) {
114 return Self::F32;
115 }
116 // Mixed int+float: bump to the smallest float that covers both.
117 let promote_int_to_float = |int: Self, float: Self| -> Self {
118 match (int, float) {
119 (_, Self::F64) => Self::F64,
120 (Self::I64, _) => Self::F64, // 64-bit int needs F64
121 (_, Self::F32) => Self::F32,
122 (_, Self::F16) | (_, Self::BF16) => Self::F32, // safe upcast
123 _ => float,
124 }
125 };
126 match (
127 self.is_int(),
128 other.is_int(),
129 self.is_float(),
130 other.is_float(),
131 ) {
132 (true, false, false, true) => promote_int_to_float(self, other),
133 (false, true, true, false) => promote_int_to_float(other, self),
134 _ => {
135 if self.promotion_rank() >= other.promotion_rank() {
136 self
137 } else {
138 other
139 }
140 }
141 }
142 }
143}
144
145/// Per-element semantics that don't fit into a flat `DType` enum
146/// (plan #40). Mirrors MAX's `layout/element.mojo` `Element` type:
147/// `DType` says "f8", but two FP8 variants exist (e4m3 and e5m2)
148/// with different range/precision tradeoffs. Saturation policy
149/// (clamp on overflow vs. wrap) is similarly orthogonal.
150///
151/// Today most ops only care about `dtype`; downstream quantization
152/// kernels read `subtype` and `saturating` to pick the right
153/// dequant. Building this in early prevents the "every op grew its
154/// own ad-hoc fp8 flag" mess MAX hit in v1.
155#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
156pub struct Element {
157 pub dtype: DType,
158 /// Subtype within `dtype` for FP8 variants etc. `Standard`
159 /// for everything else.
160 pub subtype: ElementSubtype,
161 /// Whether arithmetic saturates on overflow (true for the
162 /// quantized accumulator paths) or wraps (default).
163 pub saturating: bool,
164}
165
/// How the bits of an element are interpreted *within* its storage
/// `DType`.
///
/// `Standard` means the dtype's ordinary encoding. The FP8 variants
/// describe how an 8-bit container is decoded as a float.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ElementSubtype {
    /// Ordinary encoding of the dtype; no reinterpretation.
    Standard,
    /// FP8 e4m3 (4 exponent bits, 3 mantissa bits) — narrower range,
    /// more precision.
    ///
    /// NOTE(review): NVIDIA Hopper uses the OCP `E4M3FN` encoding;
    /// `E4M3FNUZ` ("FNUZ") is a distinct AMD/Graphcore variant with
    /// different NaN/negative-zero/bias conventions. Confirm which one
    /// downstream kernels decode.
    Fp8E4m3,
    /// FP8 e5m2 (5 exponent bits, 2 mantissa bits) — wider range, less
    /// precision. It shares FP16's 5-bit exponent, so its dynamic range
    /// (max finite 57344) is close to FP16's and far narrower than BF16's.
    Fp8E5m2,
}
176
177impl Element {
178 pub const fn new(dtype: DType) -> Self {
179 Self {
180 dtype,
181 subtype: ElementSubtype::Standard,
182 saturating: false,
183 }
184 }
185 pub const fn fp8_e4m3() -> Self {
186 Self {
187 dtype: DType::U8,
188 subtype: ElementSubtype::Fp8E4m3,
189 saturating: true,
190 }
191 }
192 pub const fn fp8_e5m2() -> Self {
193 Self {
194 dtype: DType::U8,
195 subtype: ElementSubtype::Fp8E5m2,
196 saturating: true,
197 }
198 }
199 pub const fn saturating(self) -> Self {
200 Self {
201 saturating: true,
202 ..self
203 }
204 }
205}
206
207impl std::fmt::Display for DType {
208 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
209 match self {
210 Self::F32 => write!(f, "f32"),
211 Self::F16 => write!(f, "f16"),
212 Self::BF16 => write!(f, "bf16"),
213 Self::F64 => write!(f, "f64"),
214 Self::I8 => write!(f, "i8"),
215 Self::I16 => write!(f, "i16"),
216 Self::I32 => write!(f, "i32"),
217 Self::I64 => write!(f, "i64"),
218 Self::U8 => write!(f, "u8"),
219 Self::U32 => write!(f, "u32"),
220 Self::Bool => write!(f, "bool"),
221 Self::C64 => write!(f, "c64"),
222 }
223 }
224}
225
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks `lhs.promote(rhs) == expected` for each `(lhs, rhs, expected)` row.
    fn assert_promotions(rows: &[(DType, DType, DType)]) {
        for &(lhs, rhs, expected) in rows {
            assert_eq!(lhs.promote(rhs), expected);
        }
    }

    #[test]
    fn element_constructors() {
        let plain = Element::new(DType::F32);
        assert_eq!(plain.dtype, DType::F32);
        assert_eq!(plain.subtype, ElementSubtype::Standard);
        assert!(!plain.saturating);

        let fp8 = Element::fp8_e4m3();
        assert_eq!(fp8.subtype, ElementSubtype::Fp8E4m3);
        assert!(fp8.saturating);
        assert_eq!(fp8.dtype, DType::U8);

        let clamped = Element::new(DType::I32).saturating();
        assert!(clamped.saturating);
        assert_eq!(clamped.dtype, DType::I32);
    }

    #[test]
    fn promote_same() {
        assert_promotions(&[
            (DType::F32, DType::F32, DType::F32),
            (DType::I8, DType::I8, DType::I8),
        ]);
    }

    #[test]
    fn promote_int_widening() {
        assert_promotions(&[
            (DType::I8, DType::I16, DType::I16),
            (DType::I32, DType::I64, DType::I64),
        ]);
    }

    #[test]
    fn promote_int_to_float() {
        assert_promotions(&[
            (DType::I32, DType::F32, DType::F32),
            (DType::I64, DType::F32, DType::F64),
            (DType::I8, DType::F16, DType::F32),
        ]);
    }

    #[test]
    fn promote_f16_bf16_goes_to_f32() {
        assert_promotions(&[
            (DType::F16, DType::BF16, DType::F32),
            (DType::BF16, DType::F16, DType::F32),
        ]);
    }

    #[test]
    fn promote_is_commutative_for_well_defined_pairs() {
        let pairs = [
            (DType::F32, DType::F16),
            (DType::I32, DType::F64),
            (DType::Bool, DType::I8),
        ];
        for (a, b) in pairs {
            let (forward, backward) = (a.promote(b), b.promote(a));
            assert_eq!(
                forward,
                backward,
                "promote({a},{b}) should equal promote({b},{a})"
            );
        }
    }
}