rlx_cpu/calibrate.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Activation-scale calibration for post-training INT8 quantization.
17//!
//! The runtime-side counterpart to `rlx_opt::quant_insert`. Compiles a
//! forward graph whose calibration "tap" nodes are already listed among
//! its outputs; the caller runs one batch at a time (filling the input
//! slots between calls) and the [`Calibrator`] accumulates a running
//! max-abs per tap. At the end, `scales()` returns `max_abs / 127.0` per
//! tap (clamped up to `1e-6` to avoid a division by zero when a tap is
//! all zeros) — the per-tensor scale that maps the calibration range into i8.
25//!
26//! Why max-abs and not e.g. the 99th percentile? Max-abs matches what
27//! the cortexm Python trainer used to do (and what the Rust trainer
28//! that replaced it does). It's symmetric (zero zero-point), maps
29//! `[-max, +max] → [-127, 127]`, and gives the worst-case-correct
30//! quantization for activations whose distributions are roughly
31//! zero-centered. Percentile-based / KL-divergence calibration are
32//! follow-ups for later.
33
34use crate::arena::Arena;
35use crate::thunk::{ThunkSchedule, compile_thunks, execute_thunks};
36use rlx_ir::{Graph, NodeId};
37
38/// Compiled calibration harness. The graph is owned by the caller —
39/// we hold a reference and the compiled artifacts (arena + schedule).
40/// The caller writes inputs and parameters into `arena_mut()` between
41/// batches.
42pub struct Calibrator<'g> {
43 graph: &'g Graph,
44 arena: Arena,
45 sched: ThunkSchedule,
46 /// `(tap_node_id, num_elements)` pairs — cached so each `step`
47 /// doesn't re-walk the graph for shape info.
48 taps: Vec<(NodeId, usize)>,
49 /// Running max-abs per tap. Index aligns with the `taps` order
50 /// the caller passed to `new`.
51 max_abs: Vec<f32>,
52}
53
54impl<'g> Calibrator<'g> {
55 /// Build a calibrator over `graph` that records max-abs at each
56 /// `tap` after every `step()`. The graph must already have those
57 /// taps in its `outputs` list (so the memory planner keeps their
58 /// arena slots alive to end-of-execution); this constructor
59 /// asserts the precondition.
60 pub fn new(graph: &'g Graph, taps: Vec<NodeId>) -> Self {
61 for &t in &taps {
62 assert!(
63 graph.outputs.contains(&t),
64 "Calibrator: tap {t} must be in graph.outputs so its slot \
65 survives the run; add it via graph.set_outputs(…)"
66 );
67 }
68 let plan = rlx_opt::memory::plan_memory(graph);
69 let arena = Arena::from_plan(plan);
70 let sched = compile_thunks(graph, &arena);
71 let n = taps.len();
72 let taps_with_len: Vec<(NodeId, usize)> = taps
73 .into_iter()
74 .map(|t| {
75 let len = graph.node(t).shape.num_elements().unwrap_or(0);
76 (t, len)
77 })
78 .collect();
79 Self {
80 graph,
81 arena,
82 sched,
83 taps: taps_with_len,
84 max_abs: vec![0.0; n],
85 }
86 }
87
88 /// Mutable arena access — for writing inputs/params before each
89 /// `step()` and (typically once at startup) for filling
90 /// `Op::Constant` data via `rlx_runtime`'s loader.
91 pub fn arena_mut(&mut self) -> &mut Arena {
92 &mut self.arena
93 }
94
95 /// Read-only arena view — for reading the tap values manually if
96 /// the caller wants something fancier than max-abs.
97 pub fn arena(&self) -> &Arena {
98 &self.arena
99 }
100
101 /// Run one forward batch, then update each tap's running max-abs.
102 pub fn step(&mut self) {
103 execute_thunks(&self.sched, self.arena.raw_buf_mut());
104 for ((tap, len), max) in self.taps.iter().zip(self.max_abs.iter_mut()) {
105 let off = self.arena.byte_offset(*tap);
106 unsafe {
107 let p = self.arena.raw_buf().as_ptr().add(off) as *const f32;
108 for i in 0..*len {
109 let v = (*p.add(i)).abs();
110 if v > *max {
111 *max = v;
112 }
113 }
114 }
115 }
116 }
117
118 /// Per-tap max-abs accumulated so far (in input order).
119 pub fn max_abs(&self) -> &[f32] {
120 &self.max_abs
121 }
122
123 /// Per-tap scale = `max_abs / 127.0`, clamped up to `1e-6`.
124 /// Use directly as the `scale` for `Op::Quantize` / `Op::Dequantize`
125 /// or `rlx_opt::CalibrationEntry::per_tensor`.
126 pub fn scales(&self) -> Vec<f32> {
127 self.max_abs.iter().map(|m| (m / 127.0).max(1e-6)).collect()
128 }
129
130 /// Borrow the inner graph (for the caller to re-look-up NodeIds
131 /// after compilation).
132 pub fn graph(&self) -> &Graph {
133 self.graph
134 }
135}
136
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::op::*;
    use rlx_ir::*;

    /// Two-tap calibration over a trivial graph: tap `x` (the input itself)
    /// and tap `y = relu(x)`. Hand-pack a few batches with known max-abs
    /// values and check the running max after each step, then `scales()`.
    #[test]
    fn calibrator_tracks_max_abs_across_batches() {
        let f = DType::F32;
        let mut g = Graph::new("calib_demo");
        let x = g.input("x", Shape::new(&[4], f));
        // The first tap *is* the input. The Relu gives the schedule a real
        // op to run and a second tap whose range differs from `x`.
        let y = g.activation(Activation::Relu, x, Shape::new(&[4], f));
        g.set_outputs(vec![x, y]); // tap on `x` and `y`

        let mut cal = Calibrator::new(&g, vec![x, y]);

        // Batch 1: max-abs of x = 3.0. Relu zeroes the negatives, leaving
        // [0, 1, 0, 0.5], so y's max-abs = 1.0.
        write_into(cal.arena_mut(), x, &[-3.0, 1.0, -2.0, 0.5]);
        cal.step();
        assert_eq!(cal.max_abs(), &[3.0, 1.0]);

        // Batch 2: x's max-abs grows to 7.0. y is all zeros, so its running
        // max stays at 1.0.
        write_into(cal.arena_mut(), x, &[-7.0, 0.0, -7.0, -2.0]);
        cal.step();
        assert_eq!(cal.max_abs(), &[7.0, 1.0]);

        // Batch 3: both grow to 10.0.
        write_into(cal.arena_mut(), x, &[10.0, 0.0, 0.0, 5.0]);
        cal.step();

        let mx = cal.max_abs();
        assert!((mx[0] - 10.0).abs() < 1e-6, "x max_abs: {}", mx[0]);
        assert!((mx[1] - 10.0).abs() < 1e-6, "y max_abs: {}", mx[1]);

        let s = cal.scales();
        assert!((s[0] - 10.0 / 127.0).abs() < 1e-6);
        assert!((s[1] - 10.0 / 127.0).abs() < 1e-6);
    }

    /// Copy `data` into `id`'s arena slot as native-endian `f32` bytes.
    /// The write is bounds-checked and makes no alignment assumption
    /// about the slot's byte offset.
    fn write_into(arena: &mut Arena, id: NodeId, data: &[f32]) {
        const F32_BYTES: usize = std::mem::size_of::<f32>();
        let off = arena.byte_offset(id);
        let dst = &mut arena.raw_buf_mut()[off..off + data.len() * F32_BYTES];
        for (slot, v) in dst.chunks_exact_mut(F32_BYTES).zip(data) {
            slot.copy_from_slice(&v.to_ne_bytes());
        }
    }
}
187}