rlx_compile/precision.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Precision policy + AutoMixedPrecision rewrite pass.
17//!
18//! The `PrecisionPolicy` is a high-level declarative spec that maps
19//! op kinds to numeric precisions. The `AutoMixedPrecision` pass
20//! consumes a policy and rewrites the graph: updates each node's
21//! shape dtype + inserts Cast nodes at precision boundaries.
22//!
23//! After this pass runs, the IR carries per-node precision info via
24//! `node.shape.dtype`, and the backend just reads it to pick the
25//! right kernel variant. Backends don't need any session-level
26//! precision flag.
27
28use rlx_fusion::pass::Pass;
29use rlx_ir::*;
30use std::collections::HashMap;
31
/// Numeric precision a policy can assign to an op.
///
/// Intentionally narrower than `DType`: it lists only the precisions the
/// compiler currently dispatches on.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Precision {
    /// Corresponds to `DType::F32`.
    F32,
    /// Corresponds to `DType::F16`.
    F16,
    /// Corresponds to `DType::BF16`.
    BF16,
}
40
41impl Precision {
42 pub fn dtype(self) -> DType {
43 match self {
44 Precision::F32 => DType::F32,
45 Precision::F16 => DType::F16,
46 Precision::BF16 => DType::BF16,
47 }
48 }
49}
50
51/// Cast configuration carried by ops that emit a typed output.
52///
53/// Inspired by TileKernels' `CastInputConfig` / `CastOutputConfig`: a single
54/// dataclass that flows from the layer down to the kernel selector, so adding
55/// new quantized formats (FP8 e4m3, FP4 e2m1, blocked scaling) becomes a
56/// matter of populating fields rather than threading new flags through call
57/// sites.
58///
59/// Today only `out_dtype` is consulted by backends — the scaling-factor
60/// fields are reserved for future quantization passes (FP8 / blocked SF).
61/// Constructed once by the precision pass and embedded in fused ops.
62#[derive(Debug, Clone, Copy, PartialEq)]
63pub struct CastConfig {
64 /// Destination dtype for the cast (fragment of the output tensor).
65 pub out_dtype: DType,
66 /// Scaling factor block size `(rows, cols)` for blocked quantization.
67 /// `None` means no scaling factor (plain cast).
68 pub sf_block: Option<(usize, usize)>,
69 /// Round scaling factors to powers of two (UE8M0 style).
70 pub round_sf: bool,
71}
72
73impl CastConfig {
74 /// Plain dtype cast with no scaling factor.
75 pub const fn plain(out_dtype: DType) -> Self {
76 Self {
77 out_dtype,
78 sf_block: None,
79 round_sf: false,
80 }
81 }
82 /// True when the cast does no work (out matches input dtype).
83 pub fn is_noop(&self, in_dtype: DType) -> bool {
84 self.out_dtype == in_dtype && self.sf_block.is_none()
85 }
86}
87
/// Coarse op categories that a precision policy is keyed on.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum OpKind {
    /// Compute-heavy ops (matmul, FusedMatMulBiasAct, conv) — the ones that
    /// gain the most from low precision.
    Compute,
    /// Reductions that need accuracy (LayerNorm, RmsNorm, Softmax).
    Reduction,
    /// Element-wise ops (Add, Mul, GELU, SiLU).
    Elementwise,
    /// Pure data movement with no math (Gather, Narrow, Reshape).
    DataMovement,
    /// User-facing nodes: inputs, parameters, outputs.
    Boundary,
}
103
104fn op_kind(op: &Op) -> OpKind {
105 match op {
106 Op::MatMul
107 | Op::FusedMatMulBiasAct { .. }
108 | Op::Conv { .. }
109 | Op::Im2Col { .. }
110 | Op::DotGeneral { .. }
111 | Op::DenseSolve
112 | Op::BatchedDenseSolve
113 | Op::Attention { .. }
114 | Op::FusedTransformerLayer { .. }
115 | Op::GroupedMatMul
116 | Op::DequantGroupedMatMul { .. }
117 | Op::DequantMoEWeights { .. }
118 | Op::LoraMatMul { .. }
119 | Op::DequantMatMul { .. }
120 | Op::QMatMul { .. }
121 | Op::QConv2d { .. }
122 | Op::Conv2dBackwardInput { .. }
123 | Op::Conv2dBackwardWeight { .. }
124 | Op::AttentionBackward { .. } => OpKind::Compute,
125 Op::LayerNorm { .. }
126 | Op::RmsNorm { .. }
127 | Op::Softmax { .. }
128 | Op::FusedResidualLN { .. }
129 | Op::FusedResidualRmsNorm { .. }
130 | Op::Reduce { .. }
131 | Op::Cumsum { .. }
132 | Op::Sample { .. }
133 | Op::SelectiveScan { .. }
134 | Op::GatedDeltaNet { .. }
135 | Op::SoftmaxCrossEntropyWithLogits
136 | Op::SoftmaxCrossEntropyBackward
137 | Op::LayerNormBackwardInput { .. }
138 | Op::LayerNormBackwardGamma { .. }
139 | Op::GroupNorm { .. } => OpKind::Reduction,
140 Op::Activation(_)
141 | Op::Binary(_)
142 | Op::FusedSwiGLU { .. }
143 | Op::Compare(_)
144 | Op::Where
145 | Op::ElementwiseRegion { .. }
146 | Op::Quantize { .. }
147 | Op::Dequantize { .. }
148 | Op::FakeQuantize { .. }
149 | Op::FakeQuantizeBackward { .. }
150 | Op::FakeQuantizeLSQ { .. }
151 | Op::FakeQuantizeLSQBackwardX { .. }
152 | Op::FakeQuantizeLSQBackwardScale { .. }
153 | Op::ReluBackward
154 | Op::ActivationBackward { .. }
155 | Op::ComplexNormSq
156 | Op::ComplexNormSqBackward
157 | Op::Conjugate => OpKind::Elementwise,
158 Op::Gather { .. }
159 | Op::Narrow { .. }
160 | Op::Reshape { .. }
161 | Op::Transpose { .. }
162 | Op::Concat { .. }
163 | Op::Expand { .. }
164 | Op::Cast { .. }
165 | Op::Rope { .. }
166 | Op::Pool { .. }
167 | Op::FusedAttentionBlock { .. }
168 | Op::TopK { .. }
169 | Op::ScatterAdd
170 | Op::MaxPool2dBackward { .. }
171 | Op::ResizeNearest2x
172 | Op::AxialRope2d { .. } => OpKind::DataMovement,
173 Op::Input { .. } | Op::Param { .. } | Op::Constant { .. } => OpKind::Boundary,
174 // Control flow: treated as data movement (the inner sub-graph
175 // gets its own precision policy applied separately).
176 Op::If { .. } | Op::While { .. } => OpKind::DataMovement,
177 // Custom user-registered ops are opaque to the precision pass
178 // — classify as Compute by default; the registered op's own
179 // implementation decides what dtype it operates at.
180 Op::Custom { .. } => OpKind::Compute,
181 Op::Scan { .. } => OpKind::Compute,
182 Op::ScanBackward { .. } => OpKind::Compute,
183 Op::ScanBackwardXs { .. } => OpKind::Compute,
184 Op::CustomFn { .. } => OpKind::Compute,
185 Op::Fft { .. } => OpKind::Compute,
186 Op::FftButterflyStage { .. } => OpKind::Compute,
187 Op::LogMel => OpKind::Compute,
188 Op::LogMelBackward => OpKind::Compute,
189 _ => OpKind::Compute,
190 }
191}
192
193/// Declarative precision policy for graph compilation.
194#[derive(Debug, Clone, Default)]
195pub enum PrecisionPolicy {
196 /// All ops at F32. Default; safe; baseline accuracy.
197 #[default]
198 AlwaysF32,
199 /// All ops at F16. Maximum speed; may lose accuracy on reductions.
200 AlwaysF16,
201 /// Mixed precision, conservative variant. Forces F32 at every reduction
202 /// boundary, matching PyTorch's pre-2024 autocast and HuggingFace's
203 /// historical default. Accuracy is the highest of the AMP variants;
204 /// performance suffers from a Cast node before and after every
205 /// LayerNorm / Softmax in the graph.
206 /// Compute → F16
207 /// Reduction → F32 (← the cast tax — see AutoMixed for the fix)
208 /// Elementwise → F16
209 /// DataMovement → F16
210 /// Boundary (input/param/output) → F32
211 AutoMixedConservative,
212 /// Mixed precision (Phase G — current default). Reductions stay in
213 /// the input dtype; the kernels themselves promote-to-f32 internally
214 /// for the accumulation. This eliminates the dozens of Cast nodes
215 /// that AutoMixedConservative inserts at LN/Softmax boundaries
216 /// without sacrificing the f32 reduction accumulation that matters.
217 /// Matches what modern PyTorch autocast actually does on Metal.
218 /// Compute → F16
219 /// Reduction → F16 (kernel accumulates in f32 internally)
220 /// Elementwise → F16
221 /// DataMovement → F16
222 /// Boundary (input/param/output) → F32
223 AutoMixed,
224 /// Mixed precision targeting BF16 on TPU/XLA. Same shape as
225 /// `AutoMixed` (compute + reduction + elementwise + data-movement
226 /// in the chosen low precision; boundaries stay F32) but the low
227 /// precision is BF16 instead of F16. BF16 is the native compute
228 /// dtype on TPU and recent GPUs; matches what JAX picks when
229 /// `jax.config.update("jax_default_dtype_bits", "bfloat16")`.
230 /// Compute → BF16
231 /// Reduction → BF16 (XLA's TPU codegen accumulates in f32)
232 /// Elementwise → BF16
233 /// DataMovement → BF16
234 /// Boundary → F32
235 AutoMixedBf16,
236 /// Explicit per-op-kind override.
237 Custom(HashMap<OpKind, Precision>),
238}
239
240impl PrecisionPolicy {
241 /// Resolve the target precision for an op kind.
242 pub fn precision_for(&self, kind: OpKind) -> Precision {
243 match self {
244 PrecisionPolicy::AlwaysF32 => Precision::F32,
245 PrecisionPolicy::AlwaysF16 => match kind {
246 OpKind::Boundary => Precision::F32, // user-facing stays f32
247 _ => Precision::F16,
248 },
249 PrecisionPolicy::AutoMixedConservative => match kind {
250 OpKind::Compute => Precision::F16,
251 OpKind::Reduction => Precision::F32,
252 OpKind::Elementwise => Precision::F16,
253 OpKind::DataMovement => Precision::F16,
254 OpKind::Boundary => Precision::F32,
255 },
256 PrecisionPolicy::AutoMixed => match kind {
257 OpKind::Compute => Precision::F16,
258 OpKind::Reduction => Precision::F16,
259 OpKind::Elementwise => Precision::F16,
260 OpKind::DataMovement => Precision::F16,
261 OpKind::Boundary => Precision::F32,
262 },
263 PrecisionPolicy::AutoMixedBf16 => match kind {
264 OpKind::Compute => Precision::BF16,
265 OpKind::Reduction => Precision::BF16,
266 OpKind::Elementwise => Precision::BF16,
267 OpKind::DataMovement => Precision::BF16,
268 OpKind::Boundary => Precision::F32,
269 },
270 PrecisionPolicy::Custom(map) => map.get(&kind).copied().unwrap_or(Precision::F32),
271 }
272 }
273}
274
275/// Pass that rewrites a graph according to a `PrecisionPolicy`.
276///
277/// For each node:
278/// 1. Look up the target precision based on op kind.
279/// 2. Update `node.shape.dtype` to that precision.
280/// 3. If any input has a different dtype, insert a Cast node before it.
281///
282/// After this pass, every node knows its compute precision via its
283/// shape dtype. Backends dispatch kernels per-node.
284pub struct AutoMixedPrecision {
285 pub policy: PrecisionPolicy,
286}
287
288impl AutoMixedPrecision {
289 pub fn new(policy: PrecisionPolicy) -> Self {
290 Self { policy }
291 }
292}
293
294impl Pass for AutoMixedPrecision {
295 fn name(&self) -> &str {
296 "auto_mixed_precision"
297 }
298
299 fn run(&self, graph: Graph) -> Graph {
300 // Skip the pass entirely for AlwaysF32 — it's a no-op.
301 if matches!(self.policy, PrecisionPolicy::AlwaysF32) {
302 return graph;
303 }
304
305 let mut new_graph = Graph::new(&graph.name);
306 // Maps old NodeId → new NodeId at its post-rewrite precision.
307 let mut id_map: HashMap<NodeId, NodeId> = HashMap::new();
308 // Tracks the precision each rewritten node ended up at.
309 let mut node_precision: HashMap<NodeId, Precision> = HashMap::new();
310 // Cast cache: avoid re-inserting identical Cast nodes.
311 // Key: (source new id, target precision)
312 let mut cast_cache: HashMap<(NodeId, Precision), NodeId> = HashMap::new();
313
314 for node in graph.nodes() {
315 let kind = op_kind(&node.op);
316 let target = self.policy.precision_for(kind);
317
318 // Inputs / params keep their original dtype (they're external);
319 // outputs stay user-visible at F32.
320 let target = match kind {
321 OpKind::Boundary => Precision::F32,
322 _ => target,
323 };
324
325 // Resolve each input: insert a Cast if precision differs.
326 let mut new_inputs = Vec::with_capacity(node.inputs.len());
327 for &in_id in &node.inputs {
328 let src_new_id = id_map[&in_id];
329 let src_prec = node_precision
330 .get(&in_id)
331 .copied()
332 .unwrap_or(Precision::F32);
333 if src_prec == target {
334 new_inputs.push(src_new_id);
335 } else {
336 // Insert (or reuse cached) cast
337 let cast_id = *cast_cache.entry((src_new_id, target)).or_insert_with(|| {
338 let shape = new_graph
339 .node(src_new_id)
340 .shape
341 .clone()
342 .with_dtype(target.dtype());
343 new_graph.add_node(Op::Cast { to: target.dtype() }, vec![src_new_id], shape)
344 });
345 new_inputs.push(cast_id);
346 }
347 }
348
349 // Build the rewritten node with the target dtype on its shape.
350 let new_shape = node.shape.clone().with_dtype(target.dtype());
351 let new_id = new_graph.add_node(node.op.clone(), new_inputs, new_shape);
352 id_map.insert(node.id, new_id);
353 node_precision.insert(node.id, target);
354 }
355
356 // Outputs always stay at F32 — cast back if needed.
357 let new_outputs: Vec<NodeId> = graph
358 .outputs
359 .iter()
360 .map(|&out_id| {
361 let src_new_id = id_map[&out_id];
362 let src_prec = node_precision
363 .get(&out_id)
364 .copied()
365 .unwrap_or(Precision::F32);
366 if src_prec == Precision::F32 {
367 src_new_id
368 } else {
369 let shape = new_graph
370 .node(src_new_id)
371 .shape
372 .clone()
373 .with_dtype(DType::F32);
374 new_graph.add_node(Op::Cast { to: DType::F32 }, vec![src_new_id], shape)
375 }
376 })
377 .collect();
378 new_graph.set_outputs(new_outputs);
379
380 new_graph
381 }
382}
383
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `x[2,4] @ w[4,3] -> [2,3]`, all F32, with the matmul as the
    /// only graph output.
    fn f32_matmul_graph() -> Graph {
        let mut graph = Graph::new("test");
        let lhs = graph.input("x", Shape::new(&[2, 4], DType::F32));
        let rhs = graph.param("w", Shape::new(&[4, 3], DType::F32));
        let product = graph.matmul(lhs, rhs, Shape::new(&[2, 3], DType::F32));
        graph.set_outputs(vec![product]);
        graph
    }

    #[test]
    fn always_f32_is_noop() {
        let rewritten =
            AutoMixedPrecision::new(PrecisionPolicy::AlwaysF32).run(f32_matmul_graph());

        // input, param, matmul — no casts
        assert_eq!(rewritten.len(), 3);
    }

    #[test]
    fn auto_mixed_inserts_casts_at_boundary() {
        let rewritten =
            AutoMixedPrecision::new(PrecisionPolicy::AutoMixed).run(f32_matmul_graph());

        // Expected nodes: input(f32), param(f32), cast(f32→f16) for x,
        // cast(f32→f16) for w, matmul(f16), cast(f16→f32) for the output.
        // That is 6 nodes in total, and the graph output must be the Cast
        // back to F32.
        assert!(rewritten.len() >= 6);
        let last = rewritten.node(rewritten.outputs[0]);
        assert!(matches!(last.op, Op::Cast { to: DType::F32 }));
    }
}
419}