Skip to main content

rlx_cpu/
calibrate.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
//! Activation-scale calibration for post-training INT8 quantization.
//!
//! The runtime-side counterpart to `rlx_opt::quant_insert`. Compiles a
//! forward graph with calibration "tap" nodes wired in as outputs;
//! the caller runs one batch at a time (filling the input slots
//! between calls) and the [`Calibrator`] accumulates the max-abs value
//! seen at each tap. At the end, `scales()` returns `max_abs / 127.0`
//! per tap (floored at `1e-6`, so an all-zero tensor does not produce a
//! zero scale and a later division by zero) — the per-tensor scale that
//! maps the calibration range onto i8.
//!
//! Why max-abs and not, e.g., the 99th percentile? Max-abs matches what
//! the cortexm Python trainer used to do (and what the Rust trainer
//! that replaced it does). It is symmetric (a zero-point of zero), maps
//! `[-max, +max] → [-127, 127]`, and never clips a value observed during
//! calibration, which suits activations whose distributions are roughly
//! zero-centered. Percentile-based and KL-divergence calibration are
//! left as follow-ups.
33
34use crate::arena::Arena;
35use crate::thunk::{ThunkSchedule, compile_thunks, execute_thunks};
36use rlx_ir::{Graph, NodeId};
37
38/// Compiled calibration harness. The graph is owned by the caller —
39/// we hold a reference and the compiled artifacts (arena + schedule).
40/// The caller writes inputs and parameters into `arena_mut()` between
41/// batches.
42pub struct Calibrator<'g> {
43    graph: &'g Graph,
44    arena: Arena,
45    sched: ThunkSchedule,
46    /// `(tap_node_id, num_elements)` pairs — cached so each `step`
47    /// doesn't re-walk the graph for shape info.
48    taps: Vec<(NodeId, usize)>,
49    /// Running max-abs per tap. Index aligns with the `taps` order
50    /// the caller passed to `new`.
51    max_abs: Vec<f32>,
52}
53
54impl<'g> Calibrator<'g> {
55    /// Build a calibrator over `graph` that records max-abs at each
56    /// `tap` after every `step()`. The graph must already have those
57    /// taps in its `outputs` list (so the memory planner keeps their
58    /// arena slots alive to end-of-execution); this constructor
59    /// asserts the precondition.
60    pub fn new(graph: &'g Graph, taps: Vec<NodeId>) -> Self {
61        for &t in &taps {
62            assert!(
63                graph.outputs.contains(&t),
64                "Calibrator: tap {t} must be in graph.outputs so its slot \
65                 survives the run; add it via graph.set_outputs(…)"
66            );
67        }
68        let plan = rlx_opt::memory::plan_memory(graph);
69        let arena = Arena::from_plan(plan);
70        let sched = compile_thunks(graph, &arena);
71        let n = taps.len();
72        let taps_with_len: Vec<(NodeId, usize)> = taps
73            .into_iter()
74            .map(|t| {
75                let len = graph.node(t).shape.num_elements().unwrap_or(0);
76                (t, len)
77            })
78            .collect();
79        Self {
80            graph,
81            arena,
82            sched,
83            taps: taps_with_len,
84            max_abs: vec![0.0; n],
85        }
86    }
87
88    /// Mutable arena access — for writing inputs/params before each
89    /// `step()` and (typically once at startup) for filling
90    /// `Op::Constant` data via `rlx_runtime`'s loader.
91    pub fn arena_mut(&mut self) -> &mut Arena {
92        &mut self.arena
93    }
94
95    /// Read-only arena view — for reading the tap values manually if
96    /// the caller wants something fancier than max-abs.
97    pub fn arena(&self) -> &Arena {
98        &self.arena
99    }
100
101    /// Run one forward batch, then update each tap's running max-abs.
102    pub fn step(&mut self) {
103        execute_thunks(&self.sched, self.arena.raw_buf_mut());
104        for ((tap, len), max) in self.taps.iter().zip(self.max_abs.iter_mut()) {
105            let off = self.arena.byte_offset(*tap);
106            unsafe {
107                let p = self.arena.raw_buf().as_ptr().add(off) as *const f32;
108                for i in 0..*len {
109                    let v = (*p.add(i)).abs();
110                    if v > *max {
111                        *max = v;
112                    }
113                }
114            }
115        }
116    }
117
118    /// Per-tap max-abs accumulated so far (in input order).
119    pub fn max_abs(&self) -> &[f32] {
120        &self.max_abs
121    }
122
123    /// Per-tap scale = `max_abs / 127.0`, clamped up to `1e-6`.
124    /// Use directly as the `scale` for `Op::Quantize` / `Op::Dequantize`
125    /// or `rlx_opt::CalibrationEntry::per_tensor`.
126    pub fn scales(&self) -> Vec<f32> {
127        self.max_abs.iter().map(|m| (m / 127.0).max(1e-6)).collect()
128    }
129
130    /// Borrow the inner graph (for the caller to re-look-up NodeIds
131    /// after compilation).
132    pub fn graph(&self) -> &Graph {
133        self.graph
134    }
135}
136
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::op::*;
    use rlx_ir::*;

    /// Two-tap calibration over a trivial graph: tap `x` (the input) and
    /// `y = relu(x)`. Hand-pack three batches with known values and check
    /// that the running max-abs after every batch, and the final
    /// `scales()`, reflect them.
    #[test]
    fn calibrator_tracks_max_abs_across_batches() {
        let f = DType::F32;
        let mut g = Graph::new("calib_demo");
        let x = g.input("x", Shape::new(&[4], f));
        // The Relu makes the graph non-trivial and gives the two taps
        // different ranges: negatives are zeroed in `y` but not in `x`.
        let y = g.activation(Activation::Relu, x, Shape::new(&[4], f));
        g.set_outputs(vec![x, y]); // taps must be graph outputs

        let mut cal = Calibrator::new(&g, vec![x, y]);

        // Batch 1: max-abs of x = 3.0; y = relu(x) = [0, 1, 0, 0.5] → 1.0.
        write_into(cal.arena_mut(), x, &[-3.0, 1.0, -2.0, 0.5]);
        cal.step();
        assert_close(cal.max_abs(), &[3.0, 1.0]);

        // Batch 2: x's max-abs grows to 7.0; y is all zeros, so its
        // running max stays at 1.0.
        write_into(cal.arena_mut(), x, &[-7.0, 0.0, -7.0, -2.0]);
        cal.step();
        assert_close(cal.max_abs(), &[7.0, 1.0]);

        // Batch 3: both grow to 10.0.
        write_into(cal.arena_mut(), x, &[10.0, 0.0, 0.0, 5.0]);
        cal.step();
        assert_close(cal.max_abs(), &[10.0, 10.0]);

        assert_close(&cal.scales(), &[10.0 / 127.0, 10.0 / 127.0]);
    }

    /// Before any batch every max-abs is 0, so each scale sits at the
    /// documented `1e-6` floor rather than 0.
    #[test]
    fn scales_are_floored_before_any_step() {
        let mut g = Graph::new("calib_floor");
        let x = g.input("x", Shape::new(&[4], DType::F32));
        g.set_outputs(vec![x]);

        let cal = Calibrator::new(&g, vec![x]);
        assert_close(cal.max_abs(), &[0.0]);
        assert_close(&cal.scales(), &[1e-6]);
    }

    /// Element-wise `|actual - expected| < 1e-6`, with a readable message.
    fn assert_close(actual: &[f32], expected: &[f32]) {
        assert_eq!(
            actual.len(),
            expected.len(),
            "length mismatch: {actual:?} vs {expected:?}"
        );
        for (i, (a, e)) in actual.iter().zip(expected).enumerate() {
            assert!((a - e).abs() < 1e-6, "index {i}: got {a}, expected {e}");
        }
    }

    /// Copy `data` into `id`'s arena slot as native-endian `f32` bytes.
    ///
    /// Writes through a bounds-checked byte slice, so a write past the end
    /// of the arena buffer panics, and no (possibly misaligned) `*mut f32`
    /// is ever formed.
    fn write_into(arena: &mut Arena, id: NodeId, data: &[f32]) {
        const F32_BYTES: usize = std::mem::size_of::<f32>();

        let off = arena.byte_offset(id);
        let buf = arena.raw_buf_mut();
        let dst = &mut buf[off..off + data.len() * F32_BYTES];
        for (chunk, v) in dst.chunks_exact_mut(F32_BYTES).zip(data) {
            chunk.copy_from_slice(&v.to_ne_bytes());
        }
    }
}