// rlx_compile/quant_insert.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

//! Quantize/dequantize insertion pass.
//!
//! The IR-rewrite half of post-training quantization (PTQ). Given a
//! per-tensor or per-channel calibration record (produced by running
//! forward on a sample batch, see `rlx_cpu::calibrate`), this pass
//! walks the graph and inserts an `Op::Quantize → Op::Dequantize` pair
//! immediately downstream of each tagged node. Consumers of the
//! original tap node are rewired to read the dequantized result, so
//! everything past the tap sees an INT8 round-tripped activation or
//! weight while the rest of the graph stays in fp32.
//!
//! Why a Q/DQ pair instead of switching the whole subgraph to INT8?
//! For PTQ this is the standard "fake-quant" pattern: the IR stays
//! coherent in fp32, but each tap incurs the precision loss of one
//! int8 quantize/dequantize round trip, simulating the on-device int8
//! path. Real INT8-arithmetic kernels (`Op::DequantMatMul`, etc.)
//! replace specific Q/DQ-bracketed regions later in the pipeline; this
//! pass only produces the canonical form.
//!
//! The scope is intentionally narrow: insert-only, no measurement. The
//! caller is responsible for filling `CalibrationRecord` from whatever
//! execution path it has access to.

38use rlx_ir::{Graph, Node, NodeId, Op, Shape};
39use std::collections::HashMap;

/// Calibrated quantization parameters for one tap.
///
/// `axis = None` is per-tensor: [`CalibrationEntry::per_tensor`] stores a
/// single scale and a single zero point. `axis = Some(d)` is per-channel
/// along axis `d`, in which case `scales` and `zero_points` must each have
/// length `tap.shape.dim(d)`.
#[derive(Debug, Clone, PartialEq)]
pub struct CalibrationEntry {
    /// Channel axis for per-channel quantization; `None` for per-tensor.
    pub axis: Option<usize>,
    /// Scale(s): one per channel, or a single value for per-tensor.
    pub scales: Vec<f32>,
    /// Zero point(s), parallel to `scales` (same length, same order).
    pub zero_points: Vec<i32>,
}

impl CalibrationEntry {
    /// Per-tensor symmetric entry: a single `scale` with zero point `0`.
    #[must_use]
    pub fn per_tensor(scale: f32) -> Self {
        Self {
            axis: None,
            scales: vec![scale],
            zero_points: vec![0],
        }
    }

    /// Per-channel symmetric entry along `axis`: every element of `scales`
    /// is paired with a zero point of `0`, so `zero_points.len()` always
    /// equals `scales.len()`.
    #[must_use]
    pub fn per_channel(axis: usize, scales: Vec<f32>) -> Self {
        let zero_points = vec![0; scales.len()];
        Self {
            axis: Some(axis),
            scales,
            zero_points,
        }
    }
}

72/// Map of tap NodeId → calibrated quant params.
73pub type CalibrationRecord = HashMap<NodeId, CalibrationEntry>;
74
75/// Insert `Quantize → Dequantize` pairs at every tap in `record`.
76/// Returns a graph where each tagged node is followed by a
77/// `Quantize → Dequantize` pair, and every consumer of the original
78/// tap reads from the dequantized output instead.
79///
80/// One-pass build: when we copy a consumer node, we rewrite any input
81/// edge that refers to a tap so it points at the tap's DQ instead.
82/// The Q and DQ nodes themselves are exempt (we identify them via
83/// their `Op::Quantize` / `Op::Dequantize` discriminants — the tap's
84/// raw value still flows in to the Quantize).
85pub fn insert_q_dq(graph: Graph, record: &CalibrationRecord) -> Graph {
86    let mut out = Graph::new(&graph.name);
87    let mut id_map: HashMap<NodeId, NodeId> = HashMap::new();
88    // For each old-graph tap NodeId, the NodeId of its dequantized
89    // replacement in `out`. Consumers of the tap rewrite their inputs
90    // to read from this id instead of the raw tap.
91    let mut tap_dq: HashMap<NodeId, NodeId> = HashMap::new();
92
93    for node in graph.nodes() {
94        // Translate `node.inputs` for the *new* graph, rerouting any
95        // tap reference to the tap's DQ.
96        let new_inputs: Vec<NodeId> = node
97            .inputs
98            .iter()
99            .map(|inp| {
100                // The Q node we'll insert next iteration is the only
101                // legal raw-tap consumer; everything else routes through
102                // DQ. Since we haven't placed the Q yet (it's inserted
103                // *after* the tap node it wraps), the only nodes we
104                // consider "Q" here are nodes we ourselves emit below.
105                // No risk of self-reference: we route via tap_dq only
106                // when it's already populated — i.e. for nodes that
107                // come after their producer was tapped.
108                tap_dq.get(inp).copied().unwrap_or(id_map[inp])
109            })
110            .collect();
111
112        let new_id = out.add_node(node.op.clone(), new_inputs, node.shape.clone());
113        id_map.insert(node.id, new_id);
114
115        if let Some(entry) = record.get(&node.id) {
116            let q = insert_quantize(new_id, node, entry, &mut out);
117            let dq = insert_dequantize(q, node, entry, &mut out);
118            tap_dq.insert(node.id, dq);
119        }
120    }
121
122    // Outputs: if a tap is also a graph output, return the DQ.
123    let new_outputs: Vec<NodeId> = graph
124        .outputs
125        .iter()
126        .map(|&id| tap_dq.get(&id).copied().unwrap_or(id_map[&id]))
127        .collect();
128    out.set_outputs(new_outputs);
129    out
130}
131
132fn insert_quantize(
133    src: NodeId,
134    src_node: &Node,
135    entry: &CalibrationEntry,
136    out: &mut Graph,
137) -> NodeId {
138    let q_shape: Shape = src_node.shape.clone().with_dtype(rlx_ir::DType::I8);
139    out.add_node(
140        Op::Quantize {
141            axis: entry.axis,
142            scales: entry.scales.clone(),
143            zero_points: entry.zero_points.clone(),
144        },
145        vec![src],
146        q_shape,
147    )
148}
149
150fn insert_dequantize(
151    q: NodeId,
152    src_node: &Node,
153    entry: &CalibrationEntry,
154    out: &mut Graph,
155) -> NodeId {
156    let dq_shape: Shape = src_node.shape.clone().with_dtype(rlx_ir::DType::F32);
157    out.add_node(
158        Op::Dequantize {
159            axis: entry.axis,
160            scales: entry.scales.clone(),
161            zero_points: entry.zero_points.clone(),
162        },
163        vec![q],
164        dq_shape,
165    )
166}
167
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::op::*;
    use rlx_ir::*;

    /// A per-tensor tap on an intermediate node gets a Q/DQ pair, and its
    /// consumer is rewired to read the Dequantize.
    #[test]
    fn inserts_q_dq_pair_after_tap() {
        let f = DType::F32;
        let mut g = Graph::new("ptq_demo");
        let x = g.input("x", Shape::new(&[4, 8], f));
        let y = g.activation(Activation::Relu, x, Shape::new(&[4, 8], f));
        let z = g.binary(BinaryOp::Add, y, y, Shape::new(&[4, 8], f));
        g.set_outputs(vec![z]);

        // Tag `y` for per-tensor quantization.
        let mut record = CalibrationRecord::new();
        record.insert(y, CalibrationEntry::per_tensor(0.05));

        let g2 = insert_q_dq(g, &record);

        // Expect: a Quantize and a Dequantize node now exist.
        assert!(
            g2.nodes()
                .iter()
                .any(|n| matches!(n.op, Op::Quantize { .. }))
        );
        assert!(
            g2.nodes()
                .iter()
                .any(|n| matches!(n.op, Op::Dequantize { .. }))
        );

        // The Add node's inputs should now reference the Dequantize
        // output, not the Relu output. Find the Add and check.
        let add = g2
            .nodes()
            .iter()
            .find(|n| matches!(n.op, Op::Binary(BinaryOp::Add)))
            .expect("add node");
        for &in_id in &add.inputs {
            let in_op = &g2.node(in_id).op;
            assert!(
                matches!(in_op, Op::Dequantize { .. }),
                "Add input should be Dequantize, got {in_op:?}"
            );
        }
    }

    /// With an empty record the pass adds no nodes.
    #[test]
    fn untagged_nodes_pass_through_unchanged() {
        let f = DType::F32;
        let mut g = Graph::new("no_taps");
        let x = g.input("x", Shape::new(&[4], f));
        let y = g.activation(Activation::Relu, x, Shape::new(&[4], f));
        g.set_outputs(vec![y]);

        let n_before = g.len();
        let g2 = insert_q_dq(g, &CalibrationRecord::new());
        assert_eq!(g2.len(), n_before);
    }

    /// When the tap is itself a graph output, the output is rerouted
    /// through Q/DQ and the Quantize reads the raw tap.
    #[test]
    fn tapped_graph_output_is_rerouted_to_dequantize() {
        let f = DType::F32;
        let mut g = Graph::new("tap_is_output");
        let x = g.input("x", Shape::new(&[4], f));
        let y = g.activation(Activation::Relu, x, Shape::new(&[4], f));
        g.set_outputs(vec![y]);

        let mut record = CalibrationRecord::new();
        record.insert(y, CalibrationEntry::per_tensor(0.1));

        let n_before = g.len();
        let g2 = insert_q_dq(g, &record);

        // Exactly one Quantize and one Dequantize were added.
        assert_eq!(g2.len(), n_before + 2);

        assert_eq!(g2.outputs.len(), 1);
        let dq = g2.node(g2.outputs[0]);
        assert!(
            matches!(dq.op, Op::Dequantize { .. }),
            "graph output should be Dequantize, got {:?}",
            dq.op
        );

        assert_eq!(dq.inputs.len(), 1);
        let q = g2.node(dq.inputs[0]);
        assert!(
            matches!(q.op, Op::Quantize { .. }),
            "Dequantize input should be Quantize, got {:?}",
            q.op
        );

        // The Quantize must read the raw tap, not another Q/DQ node.
        assert_eq!(q.inputs.len(), 1);
        let raw = g2.node(q.inputs[0]);
        assert!(
            !matches!(raw.op, Op::Quantize { .. } | Op::Dequantize { .. }),
            "Quantize should read the raw tap, got {:?}",
            raw.op
        );
    }

    /// `per_channel` pairs every scale with a zero point of 0.
    #[test]
    fn per_channel_entry_is_symmetric() {
        let entry = CalibrationEntry::per_channel(1, vec![0.5, 0.25, 0.125]);
        assert_eq!(entry.axis, Some(1));
        assert_eq!(entry.scales, vec![0.5, 0.25, 0.125]);
        assert_eq!(entry.zero_points, vec![0, 0, 0]);
    }
}