Skip to main content

rlx_compile/
precision.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Precision policy + AutoMixedPrecision rewrite pass.
17//!
18//! The `PrecisionPolicy` is a high-level declarative spec that maps
19//! op kinds to numeric precisions. The `AutoMixedPrecision` pass
20//! consumes a policy and rewrites the graph: updates each node's
21//! shape dtype + inserts Cast nodes at precision boundaries.
22//!
23//! After this pass runs, the IR carries per-node precision info via
24//! `node.shape.dtype`, and the backend just reads it to pick the
25//! right kernel variant. Backends don't need any session-level
26//! precision flag.
27
28use rlx_fusion::pass::Pass;
29use rlx_ir::*;
30use std::collections::HashMap;
31
32/// Which numeric precision to use for an op.
33/// (Subset of DType — only the ones we currently dispatch on.)
34#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
35pub enum Precision {
36    F32,
37    F16,
38    BF16,
39}
40
41impl Precision {
42    pub fn dtype(self) -> DType {
43        match self {
44            Precision::F32 => DType::F32,
45            Precision::F16 => DType::F16,
46            Precision::BF16 => DType::BF16,
47        }
48    }
49}
50
51/// Cast configuration carried by ops that emit a typed output.
52///
53/// Inspired by TileKernels' `CastInputConfig` / `CastOutputConfig`: a single
54/// dataclass that flows from the layer down to the kernel selector, so adding
55/// new quantized formats (FP8 e4m3, FP4 e2m1, blocked scaling) becomes a
56/// matter of populating fields rather than threading new flags through call
57/// sites.
58///
59/// Today only `out_dtype` is consulted by backends — the scaling-factor
60/// fields are reserved for future quantization passes (FP8 / blocked SF).
61/// Constructed once by the precision pass and embedded in fused ops.
62#[derive(Debug, Clone, Copy, PartialEq)]
63pub struct CastConfig {
64    /// Destination dtype for the cast (fragment of the output tensor).
65    pub out_dtype: DType,
66    /// Scaling factor block size `(rows, cols)` for blocked quantization.
67    /// `None` means no scaling factor (plain cast).
68    pub sf_block: Option<(usize, usize)>,
69    /// Round scaling factors to powers of two (UE8M0 style).
70    pub round_sf: bool,
71}
72
73impl CastConfig {
74    /// Plain dtype cast with no scaling factor.
75    pub const fn plain(out_dtype: DType) -> Self {
76        Self {
77            out_dtype,
78            sf_block: None,
79            round_sf: false,
80        }
81    }
82    /// True when the cast does no work (out matches input dtype).
83    pub fn is_noop(&self, in_dtype: DType) -> bool {
84        self.out_dtype == in_dtype && self.sf_block.is_none()
85    }
86}
87
/// Coarse op category that a precision policy assigns a precision to.
///
/// Every IR op is mapped to exactly one category by `op_kind`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OpKind {
    /// Compute-heavy ops (matmuls, convolutions, attention, dense solves)
    /// where low precision pays off most. Also the fallback for ops that
    /// have no explicit category.
    Compute,
    /// Accuracy-sensitive ops: norms, softmax, generic reductions,
    /// cumulative sums, and recurrent / scan layers.
    Reduction,
    /// Element-wise math: activations, binary ops, comparisons, `Where`,
    /// and (fake-)quantization.
    Elementwise,
    /// Layout and indexing ops that move data rather than compute on it
    /// (gather, narrow, reshape, transpose, concat, ...), plus `Cast` and
    /// control flow.
    DataMovement,
    /// Graph-external values: `Input`, `Param`, and `Constant` nodes.
    Boundary,
}
103
104fn op_kind(op: &Op) -> OpKind {
105    match op {
106        Op::MatMul
107        | Op::FusedMatMulBiasAct { .. }
108        | Op::Conv { .. }
109        | Op::Im2Col { .. }
110        | Op::DotGeneral { .. }
111        | Op::DenseSolve
112        | Op::BatchedDenseSolve
113        | Op::Attention { .. }
114        | Op::FusedTransformerLayer { .. }
115        | Op::GroupedMatMul
116        | Op::DequantGroupedMatMul { .. }
117        | Op::DequantMoEWeights { .. }
118        | Op::LoraMatMul { .. }
119        | Op::DequantMatMul { .. }
120        | Op::QMatMul { .. }
121        | Op::QConv2d { .. }
122        | Op::Conv2dBackwardInput { .. }
123        | Op::Conv2dBackwardWeight { .. }
124        | Op::AttentionBackward { .. } => OpKind::Compute,
125        Op::LayerNorm { .. }
126        | Op::RmsNorm { .. }
127        | Op::Softmax { .. }
128        | Op::FusedResidualLN { .. }
129        | Op::FusedResidualRmsNorm { .. }
130        | Op::Reduce { .. }
131        | Op::Cumsum { .. }
132        | Op::Sample { .. }
133        | Op::SelectiveScan { .. }
134        | Op::GatedDeltaNet { .. }
135        | Op::Lstm { .. }
136        | Op::Gru { .. }
137        | Op::Rnn { .. }
138        | Op::Mamba2 { .. }
139        | Op::SoftmaxCrossEntropyWithLogits
140        | Op::SoftmaxCrossEntropyBackward
141        | Op::LayerNormBackwardInput { .. }
142        | Op::LayerNormBackwardGamma { .. }
143        | Op::GroupNorm { .. } => OpKind::Reduction,
144        Op::Activation(_)
145        | Op::Binary(_)
146        | Op::FusedSwiGLU { .. }
147        | Op::Compare(_)
148        | Op::Where
149        | Op::ElementwiseRegion { .. }
150        | Op::Quantize { .. }
151        | Op::Dequantize { .. }
152        | Op::FakeQuantize { .. }
153        | Op::FakeQuantizeBackward { .. }
154        | Op::FakeQuantizeLSQ { .. }
155        | Op::FakeQuantizeLSQBackwardX { .. }
156        | Op::FakeQuantizeLSQBackwardScale { .. }
157        | Op::ReluBackward
158        | Op::ActivationBackward { .. }
159        | Op::ComplexNormSq
160        | Op::ComplexNormSqBackward
161        | Op::Conjugate => OpKind::Elementwise,
162        Op::Gather { .. }
163        | Op::Narrow { .. }
164        | Op::Reshape { .. }
165        | Op::Transpose { .. }
166        | Op::Concat { .. }
167        | Op::Expand { .. }
168        | Op::Cast { .. }
169        | Op::Rope { .. }
170        | Op::Pool { .. }
171        | Op::FusedAttentionBlock { .. }
172        | Op::TopK { .. }
173        | Op::ScatterAdd
174        | Op::MaxPool2dBackward { .. }
175        | Op::ResizeNearest2x
176        | Op::AxialRope2d { .. } => OpKind::DataMovement,
177        Op::Input { .. } | Op::Param { .. } | Op::Constant { .. } => OpKind::Boundary,
178        // Control flow: treated as data movement (the inner sub-graph
179        // gets its own precision policy applied separately).
180        Op::If { .. } | Op::While { .. } => OpKind::DataMovement,
181        // Custom user-registered ops are opaque to the precision pass
182        // — classify as Compute by default; the registered op's own
183        // implementation decides what dtype it operates at.
184        Op::Custom { .. } => OpKind::Compute,
185        Op::Scan { .. } => OpKind::Compute,
186        Op::ScanBackward { .. } => OpKind::Compute,
187        Op::ScanBackwardXs { .. } => OpKind::Compute,
188        Op::CustomFn { .. } => OpKind::Compute,
189        Op::Fft { .. } => OpKind::Compute,
190        Op::FftButterflyStage { .. } => OpKind::Compute,
191        Op::LogMel => OpKind::Compute,
192        Op::LogMelBackward => OpKind::Compute,
193        _ => OpKind::Compute,
194    }
195}
196
197/// Declarative precision policy for graph compilation.
198#[derive(Debug, Clone, Default)]
199pub enum PrecisionPolicy {
200    /// All ops at F32. Default; safe; baseline accuracy.
201    #[default]
202    AlwaysF32,
203    /// All ops at F16. Maximum speed; may lose accuracy on reductions.
204    AlwaysF16,
205    /// Mixed precision, conservative variant. Forces F32 at every reduction
206    /// boundary, matching PyTorch's pre-2024 autocast and HuggingFace's
207    /// historical default. Accuracy is the highest of the AMP variants;
208    /// performance suffers from a Cast node before and after every
209    /// LayerNorm / Softmax in the graph.
210    ///   Compute → F16
211    ///   Reduction → F32  (← the cast tax — see AutoMixed for the fix)
212    ///   Elementwise → F16
213    ///   DataMovement → F16
214    ///   Boundary (input/param/output) → F32
215    AutoMixedConservative,
216    /// Mixed precision (Phase G — current default). Reductions stay in
217    /// the input dtype; the kernels themselves promote-to-f32 internally
218    /// for the accumulation. This eliminates the dozens of Cast nodes
219    /// that AutoMixedConservative inserts at LN/Softmax boundaries
220    /// without sacrificing the f32 reduction accumulation that matters.
221    /// Matches what modern PyTorch autocast actually does on Metal.
222    ///   Compute → F16
223    ///   Reduction → F16  (kernel accumulates in f32 internally)
224    ///   Elementwise → F16
225    ///   DataMovement → F16
226    ///   Boundary (input/param/output) → F32
227    AutoMixed,
228    /// Mixed precision targeting BF16 on TPU/XLA. Same shape as
229    /// `AutoMixed` (compute + reduction + elementwise + data-movement
230    /// in the chosen low precision; boundaries stay F32) but the low
231    /// precision is BF16 instead of F16. BF16 is the native compute
232    /// dtype on TPU and recent GPUs; matches what JAX picks when
233    /// `jax.config.update("jax_default_dtype_bits", "bfloat16")`.
234    ///   Compute → BF16
235    ///   Reduction → BF16  (XLA's TPU codegen accumulates in f32)
236    ///   Elementwise → BF16
237    ///   DataMovement → BF16
238    ///   Boundary → F32
239    AutoMixedBf16,
240    /// Explicit per-op-kind override.
241    Custom(HashMap<OpKind, Precision>),
242}
243
244impl PrecisionPolicy {
245    /// Resolve the target precision for an op kind.
246    pub fn precision_for(&self, kind: OpKind) -> Precision {
247        match self {
248            PrecisionPolicy::AlwaysF32 => Precision::F32,
249            PrecisionPolicy::AlwaysF16 => match kind {
250                OpKind::Boundary => Precision::F32, // user-facing stays f32
251                _ => Precision::F16,
252            },
253            PrecisionPolicy::AutoMixedConservative => match kind {
254                OpKind::Compute => Precision::F16,
255                OpKind::Reduction => Precision::F32,
256                OpKind::Elementwise => Precision::F16,
257                OpKind::DataMovement => Precision::F16,
258                OpKind::Boundary => Precision::F32,
259            },
260            PrecisionPolicy::AutoMixed => match kind {
261                OpKind::Compute => Precision::F16,
262                OpKind::Reduction => Precision::F16,
263                OpKind::Elementwise => Precision::F16,
264                OpKind::DataMovement => Precision::F16,
265                OpKind::Boundary => Precision::F32,
266            },
267            PrecisionPolicy::AutoMixedBf16 => match kind {
268                OpKind::Compute => Precision::BF16,
269                OpKind::Reduction => Precision::BF16,
270                OpKind::Elementwise => Precision::BF16,
271                OpKind::DataMovement => Precision::BF16,
272                OpKind::Boundary => Precision::F32,
273            },
274            PrecisionPolicy::Custom(map) => map.get(&kind).copied().unwrap_or(Precision::F32),
275        }
276    }
277}
278
279/// Pass that rewrites a graph according to a `PrecisionPolicy`.
280///
281/// For each node:
282/// 1. Look up the target precision based on op kind.
283/// 2. Update `node.shape.dtype` to that precision.
284/// 3. If any input has a different dtype, insert a Cast node before it.
285///
286/// After this pass, every node knows its compute precision via its
287/// shape dtype. Backends dispatch kernels per-node.
288pub struct AutoMixedPrecision {
289    pub policy: PrecisionPolicy,
290}
291
292impl AutoMixedPrecision {
293    pub fn new(policy: PrecisionPolicy) -> Self {
294        Self { policy }
295    }
296}
297
298impl Pass for AutoMixedPrecision {
299    fn name(&self) -> &str {
300        "auto_mixed_precision"
301    }
302
303    fn run(&self, graph: Graph) -> Graph {
304        // Skip the pass entirely for AlwaysF32 — it's a no-op.
305        if matches!(self.policy, PrecisionPolicy::AlwaysF32) {
306            return graph;
307        }
308
309        let mut new_graph = Graph::new(&graph.name);
310        // Maps old NodeId → new NodeId at its post-rewrite precision.
311        let mut id_map: HashMap<NodeId, NodeId> = HashMap::new();
312        // Tracks the precision each rewritten node ended up at.
313        let mut node_precision: HashMap<NodeId, Precision> = HashMap::new();
314        // Cast cache: avoid re-inserting identical Cast nodes.
315        // Key: (source new id, target precision)
316        let mut cast_cache: HashMap<(NodeId, Precision), NodeId> = HashMap::new();
317
318        for node in graph.nodes() {
319            let kind = op_kind(&node.op);
320            let target = self.policy.precision_for(kind);
321
322            // Inputs / params keep their original dtype (they're external);
323            // outputs stay user-visible at F32.
324            let target = match kind {
325                OpKind::Boundary => Precision::F32,
326                _ => target,
327            };
328
329            // Resolve each input: insert a Cast if precision differs.
330            let mut new_inputs = Vec::with_capacity(node.inputs.len());
331            for &in_id in &node.inputs {
332                let src_new_id = id_map[&in_id];
333                let src_prec = node_precision
334                    .get(&in_id)
335                    .copied()
336                    .unwrap_or(Precision::F32);
337                if src_prec == target {
338                    new_inputs.push(src_new_id);
339                } else {
340                    // Insert (or reuse cached) cast
341                    let cast_id = *cast_cache.entry((src_new_id, target)).or_insert_with(|| {
342                        let shape = new_graph
343                            .node(src_new_id)
344                            .shape
345                            .clone()
346                            .with_dtype(target.dtype());
347                        new_graph.add_node(Op::Cast { to: target.dtype() }, vec![src_new_id], shape)
348                    });
349                    new_inputs.push(cast_id);
350                }
351            }
352
353            // Build the rewritten node with the target dtype on its shape.
354            let new_shape = node.shape.clone().with_dtype(target.dtype());
355            let new_id = new_graph.add_node(node.op.clone(), new_inputs, new_shape);
356            id_map.insert(node.id, new_id);
357            node_precision.insert(node.id, target);
358        }
359
360        // Outputs always stay at F32 — cast back if needed.
361        let new_outputs: Vec<NodeId> = graph
362            .outputs
363            .iter()
364            .map(|&out_id| {
365                let src_new_id = id_map[&out_id];
366                let src_prec = node_precision
367                    .get(&out_id)
368                    .copied()
369                    .unwrap_or(Precision::F32);
370                if src_prec == Precision::F32 {
371                    src_new_id
372                } else {
373                    let shape = new_graph
374                        .node(src_new_id)
375                        .shape
376                        .clone()
377                        .with_dtype(DType::F32);
378                    new_graph.add_node(Op::Cast { to: DType::F32 }, vec![src_new_id], shape)
379                }
380            })
381            .collect();
382        new_graph.set_outputs(new_outputs);
383
384        new_graph
385    }
386}
387
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `x @ w` with an F32 input, an F32 param and an F32 output.
    fn matmul_graph() -> Graph {
        let mut g = Graph::new("test");
        let x = g.input("x", Shape::new(&[2, 4], DType::F32));
        let w = g.param("w", Shape::new(&[4, 3], DType::F32));
        let mm = g.matmul(x, w, Shape::new(&[2, 3], DType::F32));
        g.set_outputs(vec![mm]);
        g
    }

    #[test]
    fn always_f32_is_noop() {
        let out = AutoMixedPrecision::new(PrecisionPolicy::AlwaysF32).run(matmul_graph());
        assert_eq!(out.len(), 3); // input, param, matmul — no casts
        assert!(matches!(out.node(out.outputs[0]).op, Op::MatMul));
    }

    #[test]
    fn auto_mixed_inserts_casts_at_boundary() {
        let out = AutoMixedPrecision::new(PrecisionPolicy::AutoMixed).run(matmul_graph());

        // input(f32), param(f32), cast(f32→f16) for x, cast(f32→f16) for w,
        // matmul(f16), cast(f16→f32) for the output: exactly 6 nodes.
        assert_eq!(out.len(), 6);
        let final_node = out.node(out.outputs[0]);
        assert!(matches!(final_node.op, Op::Cast { to: DType::F32 }));
        // The output cast must wrap the rewritten matmul.
        assert!(matches!(out.node(final_node.inputs[0]).op, Op::MatMul));
    }

    #[test]
    fn builtin_policies_keep_boundaries_at_f32() {
        for policy in [
            PrecisionPolicy::AlwaysF32,
            PrecisionPolicy::AlwaysF16,
            PrecisionPolicy::AutoMixedConservative,
            PrecisionPolicy::AutoMixed,
            PrecisionPolicy::AutoMixedBf16,
        ] {
            assert_eq!(policy.precision_for(OpKind::Boundary), Precision::F32);
        }
    }

    #[test]
    fn reduction_precision_per_mixed_policy() {
        assert_eq!(
            PrecisionPolicy::AutoMixedConservative.precision_for(OpKind::Reduction),
            Precision::F32
        );
        assert_eq!(
            PrecisionPolicy::AutoMixed.precision_for(OpKind::Reduction),
            Precision::F16
        );
        assert_eq!(
            PrecisionPolicy::AutoMixedBf16.precision_for(OpKind::Reduction),
            Precision::BF16
        );
    }

    #[test]
    fn custom_policy_defaults_missing_kinds_to_f32() {
        let mut table = HashMap::new();
        table.insert(OpKind::Compute, Precision::BF16);
        let policy = PrecisionPolicy::Custom(table);
        assert_eq!(policy.precision_for(OpKind::Compute), Precision::BF16);
        assert_eq!(policy.precision_for(OpKind::Elementwise), Precision::F32);
    }
}