rlx_compile/quant_insert.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Quantize/dequantize insertion pass.
17//!
18//! The IR-rewrite half of post-training quantization. Given a
19//! per-tensor or per-channel calibration record (produced by running
20//! forward on a sample batch, see `rlx_cpu::calibrate`), this pass
21//! walks the graph and inserts `Op::Quantize → Op::Dequantize` pairs
22//! immediately downstream of each tagged node. Consumers of the
23//! original tap node are rewired to read the dequantized result, so
24//! everything past the tap sees an INT8 round-tripped activation /
25//! weight while the rest of the graph stays in fp32.
26//!
27//! Why a Q/DQ pair instead of switching the whole subgraph to INT8?
28//! For PTQ this is the standard "fake-quant" pattern — the IR stays
29//! coherent in fp32, but each tap loses one quant step of precision
30//! to simulate the on-device int8 path. Real INT8-arithmetic kernels
31//! (`Op::DequantMatMul`, etc.) replace specific Q/DQ-bracketed regions
32//! later in the pipeline; this pass just produces the canonical form.
33//!
34//! Scope intentionally narrow: insert-only, no measurement. The
35//! caller is responsible for filling `CalibrationRecord` from
36//! whatever execution path it has access to.
37
38use rlx_ir::{Graph, Node, NodeId, Op, Shape};
39use std::collections::HashMap;
40
/// Calibrated quantization parameters for a single tap.
///
/// `axis = None` is per-tensor; `axis = Some(d)` is per-channel along
/// axis `d`, in which case `scales` and `zero_points` must each have
/// length `tap.shape.dim(d)`.
///
/// Derives `PartialEq` so entries (and whole calibration records) can be
/// compared by value, e.g. in tests or when diffing calibration runs.
#[derive(Debug, Clone, PartialEq)]
pub struct CalibrationEntry {
    /// Quantization granularity: `None` for per-tensor, `Some(d)` for
    /// per-channel along axis `d`.
    pub axis: Option<usize>,
    /// Quantization scale(s); a single element in the per-tensor case,
    /// one element per channel otherwise.
    pub scales: Vec<f32>,
    /// Zero point(s), parallel to `scales`.
    pub zero_points: Vec<i32>,
}
50
51impl CalibrationEntry {
52 /// Convenience constructor for the per-tensor symmetric case.
53 pub fn per_tensor(scale: f32) -> Self {
54 Self {
55 axis: None,
56 scales: vec![scale],
57 zero_points: vec![0],
58 }
59 }
60
61 /// Per-channel symmetric (`zp = 0`) along `axis`.
62 pub fn per_channel(axis: usize, scales: Vec<f32>) -> Self {
63 let n = scales.len();
64 Self {
65 axis: Some(axis),
66 scales,
67 zero_points: vec![0; n],
68 }
69 }
70}
71
72/// Map of tap NodeId → calibrated quant params.
73pub type CalibrationRecord = HashMap<NodeId, CalibrationEntry>;
74
75/// Insert `Quantize → Dequantize` pairs at every tap in `record`.
76/// Returns a graph where each tagged node is followed by a
77/// `Quantize → Dequantize` pair, and every consumer of the original
78/// tap reads from the dequantized output instead.
79///
80/// One-pass build: when we copy a consumer node, we rewrite any input
81/// edge that refers to a tap so it points at the tap's DQ instead.
82/// The Q and DQ nodes themselves are exempt (we identify them via
83/// their `Op::Quantize` / `Op::Dequantize` discriminants — the tap's
84/// raw value still flows in to the Quantize).
85pub fn insert_q_dq(graph: Graph, record: &CalibrationRecord) -> Graph {
86 let mut out = Graph::new(&graph.name);
87 let mut id_map: HashMap<NodeId, NodeId> = HashMap::new();
88 // For each old-graph tap NodeId, the NodeId of its dequantized
89 // replacement in `out`. Consumers of the tap rewrite their inputs
90 // to read from this id instead of the raw tap.
91 let mut tap_dq: HashMap<NodeId, NodeId> = HashMap::new();
92
93 for node in graph.nodes() {
94 // Translate `node.inputs` for the *new* graph, rerouting any
95 // tap reference to the tap's DQ.
96 let new_inputs: Vec<NodeId> = node
97 .inputs
98 .iter()
99 .map(|inp| {
100 // The Q node we'll insert next iteration is the only
101 // legal raw-tap consumer; everything else routes through
102 // DQ. Since we haven't placed the Q yet (it's inserted
103 // *after* the tap node it wraps), the only nodes we
104 // consider "Q" here are nodes we ourselves emit below.
105 // No risk of self-reference: we route via tap_dq only
106 // when it's already populated — i.e. for nodes that
107 // come after their producer was tapped.
108 tap_dq.get(inp).copied().unwrap_or(id_map[inp])
109 })
110 .collect();
111
112 let new_id = out.add_node(node.op.clone(), new_inputs, node.shape.clone());
113 id_map.insert(node.id, new_id);
114
115 if let Some(entry) = record.get(&node.id) {
116 let q = insert_quantize(new_id, node, entry, &mut out);
117 let dq = insert_dequantize(q, node, entry, &mut out);
118 tap_dq.insert(node.id, dq);
119 }
120 }
121
122 // Outputs: if a tap is also a graph output, return the DQ.
123 let new_outputs: Vec<NodeId> = graph
124 .outputs
125 .iter()
126 .map(|&id| tap_dq.get(&id).copied().unwrap_or(id_map[&id]))
127 .collect();
128 out.set_outputs(new_outputs);
129 out
130}
131
132fn insert_quantize(
133 src: NodeId,
134 src_node: &Node,
135 entry: &CalibrationEntry,
136 out: &mut Graph,
137) -> NodeId {
138 let q_shape: Shape = src_node.shape.clone().with_dtype(rlx_ir::DType::I8);
139 out.add_node(
140 Op::Quantize {
141 axis: entry.axis,
142 scales: entry.scales.clone(),
143 zero_points: entry.zero_points.clone(),
144 },
145 vec![src],
146 q_shape,
147 )
148}
149
150fn insert_dequantize(
151 q: NodeId,
152 src_node: &Node,
153 entry: &CalibrationEntry,
154 out: &mut Graph,
155) -> NodeId {
156 let dq_shape: Shape = src_node.shape.clone().with_dtype(rlx_ir::DType::F32);
157 out.add_node(
158 Op::Dequantize {
159 axis: entry.axis,
160 scales: entry.scales.clone(),
161 zero_points: entry.zero_points.clone(),
162 },
163 vec![q],
164 dq_shape,
165 )
166}
167
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::op::*;
    use rlx_ir::*;

    /// Number of nodes in `g` whose op satisfies `pred`.
    fn count_ops(g: &Graph, pred: impl Fn(&Op) -> bool) -> usize {
        g.nodes().iter().filter(|n| pred(&n.op)).count()
    }

    #[test]
    fn inserts_q_dq_pair_after_tap() {
        let f = DType::F32;
        let mut g = Graph::new("ptq_demo");
        let x = g.input("x", Shape::new(&[4, 8], f));
        let y = g.activation(Activation::Relu, x, Shape::new(&[4, 8], f));
        let z = g.binary(BinaryOp::Add, y, y, Shape::new(&[4, 8], f));
        g.set_outputs(vec![z]);
        let n_before = g.len();

        // Tag `y` for per-tensor quantization.
        let mut record = CalibrationRecord::new();
        record.insert(y, CalibrationEntry::per_tensor(0.05));

        let g2 = insert_q_dq(g, &record);

        // Exactly one Q/DQ pair is added, even though `y` feeds the Add
        // through two edges.
        assert_eq!(count_ops(&g2, |op| matches!(op, Op::Quantize { .. })), 1);
        assert_eq!(count_ops(&g2, |op| matches!(op, Op::Dequantize { .. })), 1);
        assert_eq!(g2.len(), n_before + 2);

        // The Dequantize must read from the Quantize.
        let dq = g2
            .nodes()
            .iter()
            .find(|n| matches!(n.op, Op::Dequantize { .. }))
            .expect("dequantize node");
        for &q_id in &dq.inputs {
            let q_op = &g2.node(q_id).op;
            assert!(
                matches!(q_op, Op::Quantize { .. }),
                "Dequantize input should be Quantize, got {q_op:?}"
            );
        }

        // The Add node's inputs should now reference the Dequantize
        // output, not the Relu output. Find the Add and check.
        let add = g2
            .nodes()
            .iter()
            .find(|n| matches!(n.op, Op::Binary(BinaryOp::Add)))
            .expect("add node");
        for &in_id in &add.inputs {
            let in_op = &g2.node(in_id).op;
            assert!(
                matches!(in_op, Op::Dequantize { .. }),
                "Add input should be Dequantize, got {in_op:?}"
            );
        }
    }

    #[test]
    fn tapped_graph_output_is_rerouted_to_dequantize() {
        let f = DType::F32;
        let mut g = Graph::new("tap_is_output");
        let x = g.input("x", Shape::new(&[4], f));
        let y = g.activation(Activation::Relu, x, Shape::new(&[4], f));
        g.set_outputs(vec![y]);
        let n_before = g.len();

        let mut record = CalibrationRecord::new();
        record.insert(y, CalibrationEntry::per_tensor(0.1));

        let g2 = insert_q_dq(g, &record);
        assert_eq!(g2.len(), n_before + 2);

        // The single graph output must now be the Dequantize, not the
        // raw Relu.
        let mut n_outputs = 0;
        for &out_id in g2.outputs.iter() {
            let out_op = &g2.node(out_id).op;
            assert!(
                matches!(out_op, Op::Dequantize { .. }),
                "graph output should be Dequantize, got {out_op:?}"
            );
            n_outputs += 1;
        }
        assert_eq!(n_outputs, 1);
    }

    #[test]
    fn untagged_nodes_pass_through_unchanged() {
        let f = DType::F32;
        let mut g = Graph::new("no_taps");
        let x = g.input("x", Shape::new(&[4], f));
        let y = g.activation(Activation::Relu, x, Shape::new(&[4], f));
        g.set_outputs(vec![y]);

        let n_before = g.len();
        let g2 = insert_q_dq(g, &CalibrationRecord::new());
        assert_eq!(g2.len(), n_before);
        assert_eq!(count_ops(&g2, |op| matches!(op, Op::Quantize { .. })), 0);
        assert_eq!(count_ops(&g2, |op| matches!(op, Op::Dequantize { .. })), 0);
    }
}
229}