rlx_compile/precision.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Precision policy + AutoMixedPrecision rewrite pass.
17//!
18//! The `PrecisionPolicy` is a high-level declarative spec that maps
19//! op kinds to numeric precisions. The `AutoMixedPrecision` pass
20//! consumes a policy and rewrites the graph: updates each node's
21//! shape dtype + inserts Cast nodes at precision boundaries.
22//!
23//! After this pass runs, the IR carries per-node precision info via
24//! `node.shape.dtype`, and the backend just reads it to pick the
25//! right kernel variant. Backends don't need any session-level
26//! precision flag.
27
28use rlx_fusion::pass::Pass;
29use rlx_ir::*;
30use std::collections::HashMap;
31
/// Numeric precision an op can be scheduled at.
///
/// Deliberately a subset of [`DType`]: only the floating-point formats the
/// backends currently dispatch on.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Precision {
    /// 32-bit float.
    F32,
    /// 16-bit IEEE half float.
    F16,
    /// 16-bit bfloat16.
    BF16,
}
40
41impl Precision {
42 pub fn dtype(self) -> DType {
43 match self {
44 Precision::F32 => DType::F32,
45 Precision::F16 => DType::F16,
46 Precision::BF16 => DType::BF16,
47 }
48 }
49}
50
51/// Cast configuration carried by ops that emit a typed output.
52///
53/// Inspired by TileKernels' `CastInputConfig` / `CastOutputConfig`: a single
54/// dataclass that flows from the layer down to the kernel selector, so adding
55/// new quantized formats (FP8 e4m3, FP4 e2m1, blocked scaling) becomes a
56/// matter of populating fields rather than threading new flags through call
57/// sites.
58///
59/// Today only `out_dtype` is consulted by backends — the scaling-factor
60/// fields are reserved for future quantization passes (FP8 / blocked SF).
61/// Constructed once by the precision pass and embedded in fused ops.
62#[derive(Debug, Clone, Copy, PartialEq)]
63pub struct CastConfig {
64 /// Destination dtype for the cast (fragment of the output tensor).
65 pub out_dtype: DType,
66 /// Scaling factor block size `(rows, cols)` for blocked quantization.
67 /// `None` means no scaling factor (plain cast).
68 pub sf_block: Option<(usize, usize)>,
69 /// Round scaling factors to powers of two (UE8M0 style).
70 pub round_sf: bool,
71}
72
73impl CastConfig {
74 /// Plain dtype cast with no scaling factor.
75 pub const fn plain(out_dtype: DType) -> Self {
76 Self {
77 out_dtype,
78 sf_block: None,
79 round_sf: false,
80 }
81 }
82 /// True when the cast does no work (out matches input dtype).
83 pub fn is_noop(&self, in_dtype: DType) -> bool {
84 self.out_dtype == in_dtype && self.sf_block.is_none()
85 }
86}
87
/// Coarse op category that a [`PrecisionPolicy`] is keyed on.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum OpKind {
    /// Math-heavy ops (matmul, `FusedMatMulBiasAct`, convolutions) that gain
    /// the most from running at low precision.
    Compute,
    /// Accuracy-sensitive reductions such as LayerNorm, RmsNorm and Softmax.
    Reduction,
    /// Pointwise ops such as Add, Mul, GELU and SiLU.
    Elementwise,
    /// Pure data movement (Gather, Narrow, Reshape, ...) with no arithmetic.
    DataMovement,
    /// User-facing graph leaves: inputs, parameters and constants.
    Boundary,
}
103
104fn op_kind(op: &Op) -> OpKind {
105 match op {
106 Op::MatMul
107 | Op::FusedMatMulBiasAct { .. }
108 | Op::Conv { .. }
109 | Op::DotGeneral { .. }
110 | Op::DenseSolve
111 | Op::BatchedDenseSolve
112 | Op::Attention { .. }
113 | Op::FusedTransformerLayer { .. }
114 | Op::GroupedMatMul
115 | Op::DequantGroupedMatMul { .. }
116 | Op::DequantMoEWeights { .. }
117 | Op::LoraMatMul { .. }
118 | Op::DequantMatMul { .. }
119 | Op::QMatMul { .. }
120 | Op::QConv2d { .. }
121 | Op::Conv2dBackwardInput { .. }
122 | Op::Conv2dBackwardWeight { .. }
123 | Op::AttentionBackward { .. } => OpKind::Compute,
124 Op::LayerNorm { .. }
125 | Op::RmsNorm { .. }
126 | Op::Softmax { .. }
127 | Op::FusedResidualLN { .. }
128 | Op::FusedResidualRmsNorm { .. }
129 | Op::Reduce { .. }
130 | Op::Cumsum { .. }
131 | Op::Sample { .. }
132 | Op::SelectiveScan { .. }
133 | Op::GatedDeltaNet { .. }
134 | Op::SoftmaxCrossEntropyWithLogits
135 | Op::SoftmaxCrossEntropyBackward
136 | Op::LayerNormBackwardInput { .. }
137 | Op::LayerNormBackwardGamma { .. }
138 | Op::GroupNorm { .. } => OpKind::Reduction,
139 Op::Activation(_)
140 | Op::Binary(_)
141 | Op::FusedSwiGLU { .. }
142 | Op::Compare(_)
143 | Op::Where
144 | Op::ElementwiseRegion { .. }
145 | Op::Quantize { .. }
146 | Op::Dequantize { .. }
147 | Op::FakeQuantize { .. }
148 | Op::FakeQuantizeBackward { .. }
149 | Op::FakeQuantizeLSQ { .. }
150 | Op::FakeQuantizeLSQBackwardX { .. }
151 | Op::FakeQuantizeLSQBackwardScale { .. }
152 | Op::ReluBackward
153 | Op::ActivationBackward { .. }
154 | Op::ComplexNormSq
155 | Op::ComplexNormSqBackward
156 | Op::Conjugate => OpKind::Elementwise,
157 Op::Gather { .. }
158 | Op::Narrow { .. }
159 | Op::Reshape { .. }
160 | Op::Transpose { .. }
161 | Op::Concat { .. }
162 | Op::Expand { .. }
163 | Op::Cast { .. }
164 | Op::Rope { .. }
165 | Op::Pool { .. }
166 | Op::FusedAttentionBlock { .. }
167 | Op::TopK { .. }
168 | Op::ScatterAdd
169 | Op::MaxPool2dBackward { .. }
170 | Op::ResizeNearest2x
171 | Op::AxialRope2d { .. } => OpKind::DataMovement,
172 Op::Input { .. } | Op::Param { .. } | Op::Constant { .. } => OpKind::Boundary,
173 // Control flow: treated as data movement (the inner sub-graph
174 // gets its own precision policy applied separately).
175 Op::If { .. } | Op::While { .. } => OpKind::DataMovement,
176 // Custom user-registered ops are opaque to the precision pass
177 // — classify as Compute by default; the registered op's own
178 // implementation decides what dtype it operates at.
179 Op::Custom { .. } => OpKind::Compute,
180 Op::Scan { .. } => OpKind::Compute,
181 Op::ScanBackward { .. } => OpKind::Compute,
182 Op::ScanBackwardXs { .. } => OpKind::Compute,
183 Op::CustomFn { .. } => OpKind::Compute,
184 Op::Fft { .. } => OpKind::Compute,
185 _ => OpKind::Compute,
186 }
187}
188
189/// Declarative precision policy for graph compilation.
190#[derive(Debug, Clone, Default)]
191pub enum PrecisionPolicy {
192 /// All ops at F32. Default; safe; baseline accuracy.
193 #[default]
194 AlwaysF32,
195 /// All ops at F16. Maximum speed; may lose accuracy on reductions.
196 AlwaysF16,
197 /// Mixed precision, conservative variant. Forces F32 at every reduction
198 /// boundary, matching PyTorch's pre-2024 autocast and HuggingFace's
199 /// historical default. Accuracy is the highest of the AMP variants;
200 /// performance suffers from a Cast node before and after every
201 /// LayerNorm / Softmax in the graph.
202 /// Compute → F16
203 /// Reduction → F32 (← the cast tax — see AutoMixed for the fix)
204 /// Elementwise → F16
205 /// DataMovement → F16
206 /// Boundary (input/param/output) → F32
207 AutoMixedConservative,
208 /// Mixed precision (Phase G — current default). Reductions stay in
209 /// the input dtype; the kernels themselves promote-to-f32 internally
210 /// for the accumulation. This eliminates the dozens of Cast nodes
211 /// that AutoMixedConservative inserts at LN/Softmax boundaries
212 /// without sacrificing the f32 reduction accumulation that matters.
213 /// Matches what modern PyTorch autocast actually does on Metal.
214 /// Compute → F16
215 /// Reduction → F16 (kernel accumulates in f32 internally)
216 /// Elementwise → F16
217 /// DataMovement → F16
218 /// Boundary (input/param/output) → F32
219 AutoMixed,
220 /// Mixed precision targeting BF16 on TPU/XLA. Same shape as
221 /// `AutoMixed` (compute + reduction + elementwise + data-movement
222 /// in the chosen low precision; boundaries stay F32) but the low
223 /// precision is BF16 instead of F16. BF16 is the native compute
224 /// dtype on TPU and recent GPUs; matches what JAX picks when
225 /// `jax.config.update("jax_default_dtype_bits", "bfloat16")`.
226 /// Compute → BF16
227 /// Reduction → BF16 (XLA's TPU codegen accumulates in f32)
228 /// Elementwise → BF16
229 /// DataMovement → BF16
230 /// Boundary → F32
231 AutoMixedBf16,
232 /// Explicit per-op-kind override.
233 Custom(HashMap<OpKind, Precision>),
234}
235
236impl PrecisionPolicy {
237 /// Resolve the target precision for an op kind.
238 pub fn precision_for(&self, kind: OpKind) -> Precision {
239 match self {
240 PrecisionPolicy::AlwaysF32 => Precision::F32,
241 PrecisionPolicy::AlwaysF16 => match kind {
242 OpKind::Boundary => Precision::F32, // user-facing stays f32
243 _ => Precision::F16,
244 },
245 PrecisionPolicy::AutoMixedConservative => match kind {
246 OpKind::Compute => Precision::F16,
247 OpKind::Reduction => Precision::F32,
248 OpKind::Elementwise => Precision::F16,
249 OpKind::DataMovement => Precision::F16,
250 OpKind::Boundary => Precision::F32,
251 },
252 PrecisionPolicy::AutoMixed => match kind {
253 OpKind::Compute => Precision::F16,
254 OpKind::Reduction => Precision::F16,
255 OpKind::Elementwise => Precision::F16,
256 OpKind::DataMovement => Precision::F16,
257 OpKind::Boundary => Precision::F32,
258 },
259 PrecisionPolicy::AutoMixedBf16 => match kind {
260 OpKind::Compute => Precision::BF16,
261 OpKind::Reduction => Precision::BF16,
262 OpKind::Elementwise => Precision::BF16,
263 OpKind::DataMovement => Precision::BF16,
264 OpKind::Boundary => Precision::F32,
265 },
266 PrecisionPolicy::Custom(map) => map.get(&kind).copied().unwrap_or(Precision::F32),
267 }
268 }
269}
270
271/// Pass that rewrites a graph according to a `PrecisionPolicy`.
272///
273/// For each node:
274/// 1. Look up the target precision based on op kind.
275/// 2. Update `node.shape.dtype` to that precision.
276/// 3. If any input has a different dtype, insert a Cast node before it.
277///
278/// After this pass, every node knows its compute precision via its
279/// shape dtype. Backends dispatch kernels per-node.
280pub struct AutoMixedPrecision {
281 pub policy: PrecisionPolicy,
282}
283
284impl AutoMixedPrecision {
285 pub fn new(policy: PrecisionPolicy) -> Self {
286 Self { policy }
287 }
288}
289
290impl Pass for AutoMixedPrecision {
291 fn name(&self) -> &str {
292 "auto_mixed_precision"
293 }
294
295 fn run(&self, graph: Graph) -> Graph {
296 // Skip the pass entirely for AlwaysF32 — it's a no-op.
297 if matches!(self.policy, PrecisionPolicy::AlwaysF32) {
298 return graph;
299 }
300
301 let mut new_graph = Graph::new(&graph.name);
302 // Maps old NodeId → new NodeId at its post-rewrite precision.
303 let mut id_map: HashMap<NodeId, NodeId> = HashMap::new();
304 // Tracks the precision each rewritten node ended up at.
305 let mut node_precision: HashMap<NodeId, Precision> = HashMap::new();
306 // Cast cache: avoid re-inserting identical Cast nodes.
307 // Key: (source new id, target precision)
308 let mut cast_cache: HashMap<(NodeId, Precision), NodeId> = HashMap::new();
309
310 for node in graph.nodes() {
311 let kind = op_kind(&node.op);
312 let target = self.policy.precision_for(kind);
313
314 // Inputs / params keep their original dtype (they're external);
315 // outputs stay user-visible at F32.
316 let target = match kind {
317 OpKind::Boundary => Precision::F32,
318 _ => target,
319 };
320
321 // Resolve each input: insert a Cast if precision differs.
322 let mut new_inputs = Vec::with_capacity(node.inputs.len());
323 for &in_id in &node.inputs {
324 let src_new_id = id_map[&in_id];
325 let src_prec = node_precision
326 .get(&in_id)
327 .copied()
328 .unwrap_or(Precision::F32);
329 if src_prec == target {
330 new_inputs.push(src_new_id);
331 } else {
332 // Insert (or reuse cached) cast
333 let cast_id = *cast_cache.entry((src_new_id, target)).or_insert_with(|| {
334 let shape = new_graph
335 .node(src_new_id)
336 .shape
337 .clone()
338 .with_dtype(target.dtype());
339 new_graph.add_node(Op::Cast { to: target.dtype() }, vec![src_new_id], shape)
340 });
341 new_inputs.push(cast_id);
342 }
343 }
344
345 // Build the rewritten node with the target dtype on its shape.
346 let new_shape = node.shape.clone().with_dtype(target.dtype());
347 let new_id = new_graph.add_node(node.op.clone(), new_inputs, new_shape);
348 id_map.insert(node.id, new_id);
349 node_precision.insert(node.id, target);
350 }
351
352 // Outputs always stay at F32 — cast back if needed.
353 let new_outputs: Vec<NodeId> = graph
354 .outputs
355 .iter()
356 .map(|&out_id| {
357 let src_new_id = id_map[&out_id];
358 let src_prec = node_precision
359 .get(&out_id)
360 .copied()
361 .unwrap_or(Precision::F32);
362 if src_prec == Precision::F32 {
363 src_new_id
364 } else {
365 let shape = new_graph
366 .node(src_new_id)
367 .shape
368 .clone()
369 .with_dtype(DType::F32);
370 new_graph.add_node(Op::Cast { to: DType::F32 }, vec![src_new_id], shape)
371 }
372 })
373 .collect();
374 new_graph.set_outputs(new_outputs);
375
376 new_graph
377 }
378}
379
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn always_f32_is_noop() {
        let mut g = Graph::new("test");
        let x = g.input("x", Shape::new(&[2, 4], DType::F32));
        let w = g.param("w", Shape::new(&[4, 3], DType::F32));
        let mm = g.matmul(x, w, Shape::new(&[2, 3], DType::F32));
        g.set_outputs(vec![mm]);

        let pass = AutoMixedPrecision::new(PrecisionPolicy::AlwaysF32);
        let out = pass.run(g);
        assert_eq!(out.len(), 3); // input, param, matmul — no casts
        assert!(matches!(out.node(out.outputs[0]).op, Op::MatMul));
    }

    #[test]
    fn auto_mixed_inserts_casts_at_boundary() {
        let mut g = Graph::new("test");
        let x = g.input("x", Shape::new(&[2, 4], DType::F32));
        let w = g.param("w", Shape::new(&[4, 3], DType::F32));
        let mm = g.matmul(x, w, Shape::new(&[2, 3], DType::F32));
        g.set_outputs(vec![mm]);

        let pass = AutoMixedPrecision::new(PrecisionPolicy::AutoMixed);
        let out = pass.run(g);

        // input(f32), param(f32), cast(f32→f16) for x, cast(f32→f16) for w,
        // matmul(f16), cast(f16→f32) for the output = exactly 6 nodes.
        assert_eq!(out.len(), 6);
        let final_node = out.node(out.outputs[0]);
        assert!(matches!(final_node.op, Op::Cast { to: DType::F32 }));
    }

    #[test]
    fn casts_are_shared_between_consumers() {
        let mut g = Graph::new("test");
        let x = g.input("x", Shape::new(&[2, 4], DType::F32));
        let w = g.param("w", Shape::new(&[4, 3], DType::F32));
        let a = g.matmul(x, w, Shape::new(&[2, 3], DType::F32));
        let b = g.matmul(x, w, Shape::new(&[2, 3], DType::F32));
        g.set_outputs(vec![a, b]);

        let out = AutoMixedPrecision::new(PrecisionPolicy::AutoMixed).run(g);

        // input, param, one f16 cast each for x and w (shared by both
        // matmuls), two matmuls, and one f32 cast per output = 8 nodes.
        assert_eq!(out.len(), 8);
        for &o in &out.outputs {
            assert!(matches!(out.node(o).op, Op::Cast { to: DType::F32 }));
        }
    }

    #[test]
    fn policy_table() {
        let conservative = PrecisionPolicy::AutoMixedConservative;
        assert_eq!(conservative.precision_for(OpKind::Reduction), Precision::F32);
        assert_eq!(conservative.precision_for(OpKind::Compute), Precision::F16);
        assert_eq!(
            PrecisionPolicy::AutoMixed.precision_for(OpKind::Reduction),
            Precision::F16
        );
        assert_eq!(
            PrecisionPolicy::AutoMixedBf16.precision_for(OpKind::Elementwise),
            Precision::BF16
        );
        assert_eq!(
            PrecisionPolicy::AlwaysF16.precision_for(OpKind::Boundary),
            Precision::F32
        );
        assert_eq!(
            PrecisionPolicy::AlwaysF32.precision_for(OpKind::Compute),
            Precision::F32
        );

        let custom = PrecisionPolicy::Custom(HashMap::from([(OpKind::Compute, Precision::BF16)]));
        assert_eq!(custom.precision_for(OpKind::Compute), Precision::BF16);
        assert_eq!(custom.precision_for(OpKind::Reduction), Precision::F32);
    }

    #[test]
    fn op_kind_classifies_unit_ops() {
        assert_eq!(op_kind(&Op::MatMul), OpKind::Compute);
        assert_eq!(op_kind(&Op::Where), OpKind::Elementwise);
        assert_eq!(op_kind(&Op::ScatterAdd), OpKind::DataMovement);
        assert_eq!(
            op_kind(&Op::SoftmaxCrossEntropyWithLogits),
            OpKind::Reduction
        );
    }
}