Skip to main content

rlx_compile/
precision.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
//! Precision policy + AutoMixedPrecision rewrite pass.
//!
//! `PrecisionPolicy` is a high-level declarative spec that maps
//! op kinds to numeric precisions. The `AutoMixedPrecision` pass
//! consumes a policy and rewrites the graph: it updates each node's
//! shape dtype and inserts Cast nodes at precision boundaries.
//!
//! After this pass runs, the IR carries per-node precision info in
//! `node.shape.dtype`, and the backend reads it to pick the right
//! kernel variant. Backends therefore need no session-level
//! precision flag.
27
28use rlx_fusion::pass::Pass;
29use rlx_ir::*;
30use std::collections::HashMap;
31
/// Numeric precision that the precision pass can assign to an op.
///
/// This is intentionally narrower than [`DType`]: it lists only the
/// floating-point formats that backends currently dispatch kernels on.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Precision {
    /// 32-bit float.
    F32,
    /// 16-bit float (half).
    F16,
    /// 16-bit brain float.
    BF16,
}
40
41impl Precision {
42    pub fn dtype(self) -> DType {
43        match self {
44            Precision::F32 => DType::F32,
45            Precision::F16 => DType::F16,
46            Precision::BF16 => DType::BF16,
47        }
48    }
49}
50
51/// Cast configuration carried by ops that emit a typed output.
52///
53/// Inspired by TileKernels' `CastInputConfig` / `CastOutputConfig`: a single
54/// dataclass that flows from the layer down to the kernel selector, so adding
55/// new quantized formats (FP8 e4m3, FP4 e2m1, blocked scaling) becomes a
56/// matter of populating fields rather than threading new flags through call
57/// sites.
58///
59/// Today only `out_dtype` is consulted by backends — the scaling-factor
60/// fields are reserved for future quantization passes (FP8 / blocked SF).
61/// Constructed once by the precision pass and embedded in fused ops.
62#[derive(Debug, Clone, Copy, PartialEq)]
63pub struct CastConfig {
64    /// Destination dtype for the cast (fragment of the output tensor).
65    pub out_dtype: DType,
66    /// Scaling factor block size `(rows, cols)` for blocked quantization.
67    /// `None` means no scaling factor (plain cast).
68    pub sf_block: Option<(usize, usize)>,
69    /// Round scaling factors to powers of two (UE8M0 style).
70    pub round_sf: bool,
71}
72
73impl CastConfig {
74    /// Plain dtype cast with no scaling factor.
75    pub const fn plain(out_dtype: DType) -> Self {
76        Self {
77            out_dtype,
78            sf_block: None,
79            round_sf: false,
80        }
81    }
82    /// True when the cast does no work (out matches input dtype).
83    pub fn is_noop(&self, in_dtype: DType) -> bool {
84        self.out_dtype == in_dtype && self.sf_block.is_none()
85    }
86}
87
/// Coarse op category that precision policies are keyed on.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OpKind {
    /// Compute-heavy ops (matmul, fused matmul + bias + activation,
    /// convolution, ...): where low precision pays off the most.
    Compute,
    /// Accuracy-sensitive reductions (LayerNorm, RmsNorm, Softmax, ...).
    Reduction,
    /// Element-wise ops (Add, Mul, GELU, SiLU, ...).
    Elementwise,
    /// Mostly data movement with little or no arithmetic
    /// (Gather, Narrow, Reshape, ...).
    DataMovement,
    /// External data entering the graph: inputs, parameters, constants.
    Boundary,
}
103
104fn op_kind(op: &Op) -> OpKind {
105    match op {
106        Op::MatMul
107        | Op::FusedMatMulBiasAct { .. }
108        | Op::Conv { .. }
109        | Op::DotGeneral { .. }
110        | Op::DenseSolve
111        | Op::BatchedDenseSolve
112        | Op::Attention { .. }
113        | Op::FusedTransformerLayer { .. }
114        | Op::GroupedMatMul
115        | Op::DequantGroupedMatMul { .. }
116        | Op::DequantMoEWeights { .. }
117        | Op::LoraMatMul { .. }
118        | Op::DequantMatMul { .. }
119        | Op::QMatMul { .. }
120        | Op::QConv2d { .. }
121        | Op::Conv2dBackwardInput { .. }
122        | Op::Conv2dBackwardWeight { .. }
123        | Op::AttentionBackward { .. } => OpKind::Compute,
124        Op::LayerNorm { .. }
125        | Op::RmsNorm { .. }
126        | Op::Softmax { .. }
127        | Op::FusedResidualLN { .. }
128        | Op::FusedResidualRmsNorm { .. }
129        | Op::Reduce { .. }
130        | Op::Cumsum { .. }
131        | Op::Sample { .. }
132        | Op::SelectiveScan { .. }
133        | Op::GatedDeltaNet { .. }
134        | Op::SoftmaxCrossEntropyWithLogits
135        | Op::SoftmaxCrossEntropyBackward
136        | Op::LayerNormBackwardInput { .. }
137        | Op::LayerNormBackwardGamma { .. }
138        | Op::GroupNorm { .. } => OpKind::Reduction,
139        Op::Activation(_)
140        | Op::Binary(_)
141        | Op::FusedSwiGLU { .. }
142        | Op::Compare(_)
143        | Op::Where
144        | Op::ElementwiseRegion { .. }
145        | Op::Quantize { .. }
146        | Op::Dequantize { .. }
147        | Op::FakeQuantize { .. }
148        | Op::FakeQuantizeBackward { .. }
149        | Op::FakeQuantizeLSQ { .. }
150        | Op::FakeQuantizeLSQBackwardX { .. }
151        | Op::FakeQuantizeLSQBackwardScale { .. }
152        | Op::ReluBackward
153        | Op::ActivationBackward { .. }
154        | Op::ComplexNormSq
155        | Op::ComplexNormSqBackward
156        | Op::Conjugate => OpKind::Elementwise,
157        Op::Gather { .. }
158        | Op::Narrow { .. }
159        | Op::Reshape { .. }
160        | Op::Transpose { .. }
161        | Op::Concat { .. }
162        | Op::Expand { .. }
163        | Op::Cast { .. }
164        | Op::Rope { .. }
165        | Op::Pool { .. }
166        | Op::FusedAttentionBlock { .. }
167        | Op::TopK { .. }
168        | Op::ScatterAdd
169        | Op::MaxPool2dBackward { .. }
170        | Op::ResizeNearest2x
171        | Op::AxialRope2d { .. } => OpKind::DataMovement,
172        Op::Input { .. } | Op::Param { .. } | Op::Constant { .. } => OpKind::Boundary,
173        // Control flow: treated as data movement (the inner sub-graph
174        // gets its own precision policy applied separately).
175        Op::If { .. } | Op::While { .. } => OpKind::DataMovement,
176        // Custom user-registered ops are opaque to the precision pass
177        // — classify as Compute by default; the registered op's own
178        // implementation decides what dtype it operates at.
179        Op::Custom { .. } => OpKind::Compute,
180        Op::Scan { .. } => OpKind::Compute,
181        Op::ScanBackward { .. } => OpKind::Compute,
182        Op::ScanBackwardXs { .. } => OpKind::Compute,
183        Op::CustomFn { .. } => OpKind::Compute,
184        Op::Fft { .. } => OpKind::Compute,
185        _ => OpKind::Compute,
186    }
187}
188
189/// Declarative precision policy for graph compilation.
190#[derive(Debug, Clone, Default)]
191pub enum PrecisionPolicy {
192    /// All ops at F32. Default; safe; baseline accuracy.
193    #[default]
194    AlwaysF32,
195    /// All ops at F16. Maximum speed; may lose accuracy on reductions.
196    AlwaysF16,
197    /// Mixed precision, conservative variant. Forces F32 at every reduction
198    /// boundary, matching PyTorch's pre-2024 autocast and HuggingFace's
199    /// historical default. Accuracy is the highest of the AMP variants;
200    /// performance suffers from a Cast node before and after every
201    /// LayerNorm / Softmax in the graph.
202    ///   Compute → F16
203    ///   Reduction → F32  (← the cast tax — see AutoMixed for the fix)
204    ///   Elementwise → F16
205    ///   DataMovement → F16
206    ///   Boundary (input/param/output) → F32
207    AutoMixedConservative,
208    /// Mixed precision (Phase G — current default). Reductions stay in
209    /// the input dtype; the kernels themselves promote-to-f32 internally
210    /// for the accumulation. This eliminates the dozens of Cast nodes
211    /// that AutoMixedConservative inserts at LN/Softmax boundaries
212    /// without sacrificing the f32 reduction accumulation that matters.
213    /// Matches what modern PyTorch autocast actually does on Metal.
214    ///   Compute → F16
215    ///   Reduction → F16  (kernel accumulates in f32 internally)
216    ///   Elementwise → F16
217    ///   DataMovement → F16
218    ///   Boundary (input/param/output) → F32
219    AutoMixed,
220    /// Mixed precision targeting BF16 on TPU/XLA. Same shape as
221    /// `AutoMixed` (compute + reduction + elementwise + data-movement
222    /// in the chosen low precision; boundaries stay F32) but the low
223    /// precision is BF16 instead of F16. BF16 is the native compute
224    /// dtype on TPU and recent GPUs; matches what JAX picks when
225    /// `jax.config.update("jax_default_dtype_bits", "bfloat16")`.
226    ///   Compute → BF16
227    ///   Reduction → BF16  (XLA's TPU codegen accumulates in f32)
228    ///   Elementwise → BF16
229    ///   DataMovement → BF16
230    ///   Boundary → F32
231    AutoMixedBf16,
232    /// Explicit per-op-kind override.
233    Custom(HashMap<OpKind, Precision>),
234}
235
236impl PrecisionPolicy {
237    /// Resolve the target precision for an op kind.
238    pub fn precision_for(&self, kind: OpKind) -> Precision {
239        match self {
240            PrecisionPolicy::AlwaysF32 => Precision::F32,
241            PrecisionPolicy::AlwaysF16 => match kind {
242                OpKind::Boundary => Precision::F32, // user-facing stays f32
243                _ => Precision::F16,
244            },
245            PrecisionPolicy::AutoMixedConservative => match kind {
246                OpKind::Compute => Precision::F16,
247                OpKind::Reduction => Precision::F32,
248                OpKind::Elementwise => Precision::F16,
249                OpKind::DataMovement => Precision::F16,
250                OpKind::Boundary => Precision::F32,
251            },
252            PrecisionPolicy::AutoMixed => match kind {
253                OpKind::Compute => Precision::F16,
254                OpKind::Reduction => Precision::F16,
255                OpKind::Elementwise => Precision::F16,
256                OpKind::DataMovement => Precision::F16,
257                OpKind::Boundary => Precision::F32,
258            },
259            PrecisionPolicy::AutoMixedBf16 => match kind {
260                OpKind::Compute => Precision::BF16,
261                OpKind::Reduction => Precision::BF16,
262                OpKind::Elementwise => Precision::BF16,
263                OpKind::DataMovement => Precision::BF16,
264                OpKind::Boundary => Precision::F32,
265            },
266            PrecisionPolicy::Custom(map) => map.get(&kind).copied().unwrap_or(Precision::F32),
267        }
268    }
269}
270
271/// Pass that rewrites a graph according to a `PrecisionPolicy`.
272///
273/// For each node:
274/// 1. Look up the target precision based on op kind.
275/// 2. Update `node.shape.dtype` to that precision.
276/// 3. If any input has a different dtype, insert a Cast node before it.
277///
278/// After this pass, every node knows its compute precision via its
279/// shape dtype. Backends dispatch kernels per-node.
280pub struct AutoMixedPrecision {
281    pub policy: PrecisionPolicy,
282}
283
284impl AutoMixedPrecision {
285    pub fn new(policy: PrecisionPolicy) -> Self {
286        Self { policy }
287    }
288}
289
290impl Pass for AutoMixedPrecision {
291    fn name(&self) -> &str {
292        "auto_mixed_precision"
293    }
294
295    fn run(&self, graph: Graph) -> Graph {
296        // Skip the pass entirely for AlwaysF32 — it's a no-op.
297        if matches!(self.policy, PrecisionPolicy::AlwaysF32) {
298            return graph;
299        }
300
301        let mut new_graph = Graph::new(&graph.name);
302        // Maps old NodeId → new NodeId at its post-rewrite precision.
303        let mut id_map: HashMap<NodeId, NodeId> = HashMap::new();
304        // Tracks the precision each rewritten node ended up at.
305        let mut node_precision: HashMap<NodeId, Precision> = HashMap::new();
306        // Cast cache: avoid re-inserting identical Cast nodes.
307        // Key: (source new id, target precision)
308        let mut cast_cache: HashMap<(NodeId, Precision), NodeId> = HashMap::new();
309
310        for node in graph.nodes() {
311            let kind = op_kind(&node.op);
312            let target = self.policy.precision_for(kind);
313
314            // Inputs / params keep their original dtype (they're external);
315            // outputs stay user-visible at F32.
316            let target = match kind {
317                OpKind::Boundary => Precision::F32,
318                _ => target,
319            };
320
321            // Resolve each input: insert a Cast if precision differs.
322            let mut new_inputs = Vec::with_capacity(node.inputs.len());
323            for &in_id in &node.inputs {
324                let src_new_id = id_map[&in_id];
325                let src_prec = node_precision
326                    .get(&in_id)
327                    .copied()
328                    .unwrap_or(Precision::F32);
329                if src_prec == target {
330                    new_inputs.push(src_new_id);
331                } else {
332                    // Insert (or reuse cached) cast
333                    let cast_id = *cast_cache.entry((src_new_id, target)).or_insert_with(|| {
334                        let shape = new_graph
335                            .node(src_new_id)
336                            .shape
337                            .clone()
338                            .with_dtype(target.dtype());
339                        new_graph.add_node(Op::Cast { to: target.dtype() }, vec![src_new_id], shape)
340                    });
341                    new_inputs.push(cast_id);
342                }
343            }
344
345            // Build the rewritten node with the target dtype on its shape.
346            let new_shape = node.shape.clone().with_dtype(target.dtype());
347            let new_id = new_graph.add_node(node.op.clone(), new_inputs, new_shape);
348            id_map.insert(node.id, new_id);
349            node_precision.insert(node.id, target);
350        }
351
352        // Outputs always stay at F32 — cast back if needed.
353        let new_outputs: Vec<NodeId> = graph
354            .outputs
355            .iter()
356            .map(|&out_id| {
357                let src_new_id = id_map[&out_id];
358                let src_prec = node_precision
359                    .get(&out_id)
360                    .copied()
361                    .unwrap_or(Precision::F32);
362                if src_prec == Precision::F32 {
363                    src_new_id
364                } else {
365                    let shape = new_graph
366                        .node(src_new_id)
367                        .shape
368                        .clone()
369                        .with_dtype(DType::F32);
370                    new_graph.add_node(Op::Cast { to: DType::F32 }, vec![src_new_id], shape)
371                }
372            })
373            .collect();
374        new_graph.set_outputs(new_outputs);
375
376        new_graph
377    }
378}
379
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `y = x · w` with x: [2, 4] input, w: [4, 3] param, all F32.
    fn f32_matmul_graph() -> Graph {
        let mut g = Graph::new("test");
        let x = g.input("x", Shape::new(&[2, 4], DType::F32));
        let w = g.param("w", Shape::new(&[4, 3], DType::F32));
        let mm = g.matmul(x, w, Shape::new(&[2, 3], DType::F32));
        g.set_outputs(vec![mm]);
        g
    }

    #[test]
    fn always_f32_is_noop() {
        let out = AutoMixedPrecision::new(PrecisionPolicy::AlwaysF32).run(f32_matmul_graph());
        assert_eq!(out.len(), 3); // input, param, matmul — no casts
    }

    #[test]
    fn auto_mixed_inserts_casts_at_boundary() {
        let out = AutoMixedPrecision::new(PrecisionPolicy::AutoMixed).run(f32_matmul_graph());

        // input(f32), param(f32), cast x→f16, cast w→f16, matmul(f16),
        // cast f16→f32 for the output: exactly 6 nodes.
        assert_eq!(out.len(), 6);
        let out_cast = out.node(out.outputs[0]);
        assert!(matches!(out_cast.op, Op::Cast { to: DType::F32 }));
        let matmul = out.node(out_cast.inputs[0]);
        for &operand in &matmul.inputs {
            assert!(matches!(out.node(operand).op, Op::Cast { to: DType::F16 }));
        }
    }

    #[test]
    fn policies_resolve_expected_precisions() {
        let conservative = PrecisionPolicy::AutoMixedConservative;
        assert_eq!(conservative.precision_for(OpKind::Reduction), Precision::F32);
        assert_eq!(conservative.precision_for(OpKind::Compute), Precision::F16);

        let auto = PrecisionPolicy::AutoMixed;
        assert_eq!(auto.precision_for(OpKind::Reduction), Precision::F16);
        assert_eq!(auto.precision_for(OpKind::Boundary), Precision::F32);

        assert_eq!(
            PrecisionPolicy::AutoMixedBf16.precision_for(OpKind::Elementwise),
            Precision::BF16
        );
        assert_eq!(
            PrecisionPolicy::AlwaysF16.precision_for(OpKind::Boundary),
            Precision::F32
        );
        assert_eq!(
            PrecisionPolicy::default().precision_for(OpKind::Compute),
            Precision::F32
        );

        let custom =
            PrecisionPolicy::Custom(HashMap::from([(OpKind::Compute, Precision::BF16)]));
        assert_eq!(custom.precision_for(OpKind::Compute), Precision::BF16);
        assert_eq!(custom.precision_for(OpKind::Reduction), Precision::F32);
    }
}