Skip to main content

rlx_ir/
dtype.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

//! Element data types for tensors.

/// Scalar element type. Matches hardware-supported types.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DType {
    /// 32-bit float (4 bytes per element).
    F32,
    /// 16-bit float with more mantissa but less range than [`DType::BF16`].
    F16,
    /// 16-bit float with wider range but less mantissa than [`DType::F16`].
    BF16,
    /// 64-bit float (8 bytes per element).
    F64,
    /// Signed 8-bit integer.
    I8,
    /// Signed 16-bit integer.
    I16,
    /// Signed 32-bit integer.
    I32,
    /// Signed 64-bit integer.
    I64,
    /// Unsigned 8-bit integer.
    U8,
    /// Unsigned 32-bit integer.
    U32,
    /// Boolean, stored as one byte per element.
    Bool,
    /// Complex with f32 real and f32 imaginary components, stored
    /// interleaved as `[re, im, re, im, ...]`. 8 bytes per complex
    /// element. Element-wise ops (Add/Sub/Mul/Conj) follow the
    /// standard complex algebra. Reverse-mode AD on this dtype is
    /// **not yet wired** — Wirtinger conventions (∂/∂z vs ∂/∂z̄)
    /// belong to a separate pass that knows to emit conjugate-aware
    /// VJPs. The forward path is sufficient for AC analysis and
    /// FFT-based workflows that don't need to differentiate through
    /// complex math (and in fact, FFT today already encodes complex
    /// as 2N-real-block; this dtype is the natural successor).
    C64,
}

impl DType {
    /// Size in bytes of one element.
    pub const fn size_bytes(self) -> usize {
        match self {
            Self::Bool | Self::I8 | Self::U8 => 1,
            Self::F16 | Self::BF16 | Self::I16 => 2,
            Self::F32 | Self::I32 | Self::U32 => 4,
            Self::F64 | Self::I64 | Self::C64 => 8,
        }
    }

    /// True for the real floating-point dtypes (`F16`, `BF16`, `F32`,
    /// `F64`). `C64` is reported by [`DType::is_complex`] instead.
    pub const fn is_float(self) -> bool {
        matches!(self, Self::F32 | Self::F16 | Self::BF16 | Self::F64)
    }

    /// True for complex-valued dtypes. Complex elementwise ops follow
    /// standard complex algebra, distinct from the float real/imag
    /// components (e.g. complex multiply ≠ paired-real multiply).
    pub const fn is_complex(self) -> bool {
        matches!(self, Self::C64)
    }

    /// True for signed and unsigned integer dtypes. `Bool` is not
    /// counted as an integer.
    pub const fn is_int(self) -> bool {
        matches!(
            self,
            Self::I8 | Self::I16 | Self::I32 | Self::I64 | Self::U8 | Self::U32
        )
    }

    /// Promotion rank — higher means "wider, more expressive".
    /// [`DType::promote`] uses `max(rank(lhs), rank(rhs))` as its
    /// fallback once its special cases (f16+bf16, int+float, mixed
    /// signedness) have been ruled out.
    /// Borrowed from MAX's `dtype_promotion.py` pattern (#55 in
    /// PLAN.md): one module owns the table; ops query it instead of
    /// re-implementing ad-hoc rules.
    ///
    /// Ranks (low → high):
    ///   0 = Bool, 1 = U8/I8, 2 = I16/BF16, 3 = F16, 4 = U32/I32,
    ///   5 = I64, 6 = F32, 7 = F64, 8 = C64.
    /// Floats outrank ints of the same width (matches PyTorch /
    /// NumPy). BF16 promotes to F32 against F16 since BF16 has
    /// wider range but F16 has more mantissa.
    pub const fn promotion_rank(self) -> u8 {
        match self {
            Self::Bool => 0,
            Self::U8 | Self::I8 => 1,
            Self::I16 | Self::BF16 => 2,
            Self::F16 => 3,
            Self::U32 | Self::I32 => 4,
            Self::I64 => 5,
            Self::F32 => 6,
            Self::F64 => 7,
            Self::C64 => 8,
        }
    }

    /// Result dtype for a binary op between `self` and `other`.
    ///
    /// The result does not depend on argument order. Rules, in order:
    /// 1. Identical dtypes promote to themselves.
    /// 2. `f16 + bf16 → f32` (no clean lossless target).
    /// 3. Mixed int+float → `F64` if either side is 64-bit (`F64` or
    ///    `I64`), otherwise `F32`.
    /// 4. Mixed signed+unsigned ints → a signed type wide enough for
    ///    both operands (`u8 + i8 → i16`, `u32 + i32 → i64`). Without
    ///    this rule the rank tie between `U8`/`I8` and `U32`/`I32`
    ///    would make the result depend on which operand came first.
    /// 5. Everything else (including `Bool` and `C64` operands) takes
    ///    the operand with the higher [`DType::promotion_rank`].
    pub fn promote(self, other: Self) -> Self {
        if self == other {
            return self;
        }
        // f16 and bf16 trade range against mantissa, so neither can
        // represent the other; F32 covers both.
        if matches!(
            (self, other),
            (Self::F16, Self::BF16) | (Self::BF16, Self::F16)
        ) {
            return Self::F32;
        }
        if self.is_int() && other.is_float() {
            return Self::int_float_target(self, other);
        }
        if self.is_float() && other.is_int() {
            return Self::int_float_target(other, self);
        }
        if self.is_int() && other.is_int() && self.is_unsigned_int() != other.is_unsigned_int() {
            return if self.is_unsigned_int() {
                Self::mixed_sign_target(self, other)
            } else {
                Self::mixed_sign_target(other, self)
            };
        }
        // No rank ties remain between distinct dtypes at this point,
        // so the comparison below is order-independent.
        if self.promotion_rank() >= other.promotion_rank() {
            self
        } else {
            other
        }
    }

    /// True for the unsigned integer dtypes.
    const fn is_unsigned_int(self) -> bool {
        matches!(self, Self::U8 | Self::U32)
    }

    /// Float dtype for an int/float pair: `F64` when either side is
    /// 64-bit, otherwise `F32` (16-bit floats are upcast).
    const fn int_float_target(int: Self, float: Self) -> Self {
        match (int, float) {
            (_, Self::F64) | (Self::I64, _) => Self::F64,
            _ => Self::F32,
        }
    }

    /// Signed dtype able to hold every value of both an unsigned and
    /// a signed integer operand.
    const fn mixed_sign_target(unsigned: Self, signed: Self) -> Self {
        if signed.size_bytes() > unsigned.size_bytes() {
            // A strictly wider signed type already covers the unsigned range.
            signed
        } else {
            match unsigned {
                Self::U8 => Self::I16,
                // Only `U32` reaches this arm; `I64` is the widest signed type.
                _ => Self::I64,
            }
        }
    }
}
144
/// Check that `value` is a finite whole number and convert it to `i64`.
///
/// `name` is the dtype label interpolated into the error message.
/// The conversion is a plain `as` cast, which saturates for magnitudes
/// outside the `i64` range; callers perform their own range checks.
///
/// # Errors
///
/// Returns a message when `value` is NaN/infinite or has a fractional part.
fn integral_scalar(value: f64, name: &str) -> Result<i64, String> {
    match value {
        v if !v.is_finite() => Err(format!(
            "constant value {value} is not finite for dtype {name}"
        )),
        v if v.fract() != 0.0 => Err(format!(
            "constant value {value} must be integral for dtype {name}"
        )),
        v => Ok(v as i64),
    }
}
158
159/// Encode a scalar as little-endian bytes for [`crate::op::Op::Constant`].
160pub fn scalar_constant_bytes(value: f64, dtype: DType) -> Result<Vec<u8>, String> {
161    let out_of_range =
162        |name: &str| format!("constant value {value} is out of range for dtype {name}");
163    match dtype {
164        DType::F32 => Ok((value as f32).to_le_bytes().to_vec()),
165        DType::F64 => Ok(value.to_le_bytes().to_vec()),
166        DType::I8 => {
167            let v = integral_scalar(value, "i8")?;
168            if !(i8::MIN as i64..=i8::MAX as i64).contains(&v) {
169                return Err(out_of_range("i8"));
170            }
171            Ok((v as i8).to_le_bytes().to_vec())
172        }
173        DType::I16 => {
174            let v = integral_scalar(value, "i16")?;
175            if !(i16::MIN as i64..=i16::MAX as i64).contains(&v) {
176                return Err(out_of_range("i16"));
177            }
178            Ok((v as i16).to_le_bytes().to_vec())
179        }
180        DType::I32 => {
181            let v = integral_scalar(value, "i32")?;
182            if !(i32::MIN as i64..=i32::MAX as i64).contains(&v) {
183                return Err(out_of_range("i32"));
184            }
185            Ok((v as i32).to_le_bytes().to_vec())
186        }
187        DType::I64 => {
188            if !value.is_finite() {
189                return Err(format!(
190                    "constant value {value} is not finite for dtype i64"
191                ));
192            }
193            if value.fract() != 0.0 {
194                return Err(format!(
195                    "constant value {value} must be integral for dtype i64"
196                ));
197            }
198            // `i64::MAX as f64` rounds up to 2^63; use open bounds at ±2^63.
199            if value >= 9.223372036854776e18 || value < -9.223372036854776e18 {
200                return Err(out_of_range("i64"));
201            }
202            Ok((value as i64).to_le_bytes().to_vec())
203        }
204        DType::U8 => {
205            let v = integral_scalar(value, "u8")?;
206            if !(0..=u8::MAX as i64).contains(&v) {
207                return Err(out_of_range("u8"));
208            }
209            Ok((v as u8).to_le_bytes().to_vec())
210        }
211        DType::U32 => {
212            let v = integral_scalar(value, "u32")?;
213            if v < 0 || v > u32::MAX as i64 {
214                return Err(out_of_range("u32"));
215            }
216            Ok((v as u32).to_le_bytes().to_vec())
217        }
218        DType::Bool => Ok(vec![u8::from(value != 0.0)]),
219        DType::F16 | DType::BF16 | DType::C64 => Err(format!(
220            "scalar literal dtype '{dtype:?}' is built via f32 constant + cast"
221        )),
222    }
223}
224
225/// Per-element semantics that don't fit into a flat `DType` enum
226/// (plan #40). Mirrors MAX's `layout/element.mojo` `Element` type:
227/// `DType` says "f8", but two FP8 variants exist (e4m3 and e5m2)
228/// with different range/precision tradeoffs. Saturation policy
229/// (clamp on overflow vs. wrap) is similarly orthogonal.
230///
231/// Today most ops only care about `dtype`; downstream quantization
232/// kernels read `subtype` and `saturating` to pick the right
233/// dequant. Building this in early prevents the "every op grew its
234/// own ad-hoc fp8 flag" mess MAX hit in v1.
235#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
236pub struct Element {
237    pub dtype: DType,
238    /// Subtype within `dtype` for FP8 variants etc. `Standard`
239    /// for everything else.
240    pub subtype: ElementSubtype,
241    /// Whether arithmetic saturates on overflow (true for the
242    /// quantized accumulator paths) or wraps (default).
243    pub saturating: bool,
244}
245
246#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
247pub enum ElementSubtype {
248    Standard,
249    /// FP8 e4m3 (4 exp bits, 3 mantissa) — lower range, more
250    /// precision; matches NVIDIA's "FNUZ" Hopper format.
251    Fp8E4m3,
252    /// FP8 e5m2 (5 exp bits, 2 mantissa) — wider range, less
253    /// precision; closer to bf16 in dynamic range.
254    Fp8E5m2,
255}
256
257impl Element {
258    pub const fn new(dtype: DType) -> Self {
259        Self {
260            dtype,
261            subtype: ElementSubtype::Standard,
262            saturating: false,
263        }
264    }
265    pub const fn fp8_e4m3() -> Self {
266        Self {
267            dtype: DType::U8,
268            subtype: ElementSubtype::Fp8E4m3,
269            saturating: true,
270        }
271    }
272    pub const fn fp8_e5m2() -> Self {
273        Self {
274            dtype: DType::U8,
275            subtype: ElementSubtype::Fp8E5m2,
276            saturating: true,
277        }
278    }
279    pub const fn saturating(self) -> Self {
280        Self {
281            saturating: true,
282            ..self
283        }
284    }
285}
286
287impl std::fmt::Display for DType {
288    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
289        match self {
290            Self::F32 => write!(f, "f32"),
291            Self::F16 => write!(f, "f16"),
292            Self::BF16 => write!(f, "bf16"),
293            Self::F64 => write!(f, "f64"),
294            Self::I8 => write!(f, "i8"),
295            Self::I16 => write!(f, "i16"),
296            Self::I32 => write!(f, "i32"),
297            Self::I64 => write!(f, "i64"),
298            Self::U8 => write!(f, "u8"),
299            Self::U32 => write!(f, "u32"),
300            Self::Bool => write!(f, "bool"),
301            Self::C64 => write!(f, "c64"),
302        }
303    }
304}
305
#[cfg(test)]
mod tests {
    use super::*;

    /// Constructors and the `saturating` builder set the expected fields.
    #[test]
    fn element_constructors() {
        let plain = Element::new(DType::F32);
        assert_eq!(
            (plain.dtype, plain.subtype, plain.saturating),
            (DType::F32, ElementSubtype::Standard, false)
        );

        let fp8 = Element::fp8_e4m3();
        assert_eq!(
            (fp8.dtype, fp8.subtype, fp8.saturating),
            (DType::U8, ElementSubtype::Fp8E4m3, true)
        );

        let sat = Element::new(DType::I32).saturating();
        assert_eq!((sat.dtype, sat.saturating), (DType::I32, true));
    }

    /// Identical operands promote to themselves.
    #[test]
    fn promote_same() {
        for dt in [DType::F32, DType::I8] {
            assert_eq!(dt.promote(dt), dt);
        }
    }

    /// Same-signedness integers widen to the larger operand.
    #[test]
    fn promote_int_widening() {
        let cases = [
            (DType::I8, DType::I16, DType::I16),
            (DType::I32, DType::I64, DType::I64),
        ];
        for (lhs, rhs, want) in cases {
            assert_eq!(lhs.promote(rhs), want);
        }
    }

    /// Int + float lands on a float wide enough for both.
    #[test]
    fn promote_int_to_float() {
        let cases = [
            (DType::I32, DType::F32, DType::F32),
            (DType::I64, DType::F32, DType::F64),
            (DType::I8, DType::F16, DType::F32),
        ];
        for (lhs, rhs, want) in cases {
            assert_eq!(lhs.promote(rhs), want);
        }
    }

    /// The two 16-bit floats meet at F32 regardless of order.
    #[test]
    fn promote_f16_bf16_goes_to_f32() {
        let (half, brain) = (DType::F16, DType::BF16);
        assert_eq!(half.promote(brain), DType::F32);
        assert_eq!(brain.promote(half), DType::F32);
    }

    #[test]
    fn promote_is_commutative_for_well_defined_pairs() {
        for (a, b) in [
            (DType::F32, DType::F16),
            (DType::I32, DType::F64),
            (DType::Bool, DType::I8),
        ] {
            assert_eq!(
                a.promote(b),
                b.promote(a),
                "promote({a},{b}) should equal promote({b},{a})"
            );
        }
    }

    /// Encoded bytes equal the little-endian bytes of the native value.
    #[test]
    fn scalar_constant_bytes_round_trips() {
        let encode = |v: f64, dt: DType| scalar_constant_bytes(v, dt).unwrap();

        assert_eq!(encode(2.5, DType::F32), 2.5f32.to_le_bytes().to_vec());
        assert_eq!(encode(-1.0, DType::F64), (-1.0f64).to_le_bytes().to_vec());
        assert_eq!(encode(7.0, DType::I32), 7i32.to_le_bytes().to_vec());
        assert_eq!(encode(0.0, DType::Bool), vec![0]);
        assert_eq!(encode(1.0, DType::Bool), vec![1]);
    }

    /// Out-of-range and fractional integer constants are rejected.
    #[test]
    fn scalar_constant_bytes_rejects_out_of_range() {
        let rejected = [
            (128.0, DType::I8),
            (-1.0, DType::U32),
            (9.223372036854776e18, DType::I64),
            (2.5, DType::I32),
        ];
        for (value, dtype) in rejected {
            assert!(scalar_constant_bytes(value, dtype).is_err());
        }
    }

    /// F16, BF16 and C64 cannot be encoded directly.
    #[test]
    fn scalar_constant_bytes_rejects_low_precision_direct() {
        for dtype in [DType::F16, DType::BF16, DType::C64] {
            assert!(scalar_constant_bytes(1.0, dtype).is_err());
        }
    }
}