use crate::arena::Arena;
use crate::thunk::{ThunkSchedule, compile_thunks, execute_thunks};
use rlx_ir::{Graph, NodeId};
/// Runs a compiled graph repeatedly over calibration batches and records, for
/// each tapped node, the largest absolute value observed across all runs.
///
/// Typical use: write a batch's inputs through [`Calibrator::arena_mut`], call
/// [`Calibrator::step`], repeat, then read [`Calibrator::max_abs`] or
/// [`Calibrator::scales`].
pub struct Calibrator<'g> {
    graph: &'g Graph,
    arena: Arena,
    sched: ThunkSchedule,
    /// Tapped nodes paired with their element count (0 if the shape reports none).
    taps: Vec<(NodeId, usize)>,
    /// Running max |x| per tap, index-aligned with `taps`.
    max_abs: Vec<f32>,
}

impl<'g> Calibrator<'g> {
    /// Plans memory for `graph`, compiles its thunk schedule into a fresh arena,
    /// and prepares to track the nodes in `taps` (in the given order).
    ///
    /// # Panics
    ///
    /// If any tap is not listed in `graph.outputs`.
    pub fn new(graph: &'g Graph, taps: Vec<NodeId>) -> Self {
        // Only output slots are guaranteed to still hold their values after a
        // run, so reject any tap that is not an output up front.
        for &t in &taps {
            assert!(
                graph.outputs.contains(&t),
                "Calibrator: tap {t} must be in graph.outputs so its slot \
                 survives the run; add it via graph.set_outputs(…)"
            );
        }

        let plan = rlx_opt::memory::plan_memory(graph);
        let arena = Arena::from_plan(plan);
        let sched = compile_thunks(graph, &arena);

        let taps: Vec<(NodeId, usize)> = taps
            .into_iter()
            .map(|t| {
                // NOTE(review): a shape without a known element count silently
                // becomes length 0, so that tap is never measured and its scale
                // stays at the 1e-6 floor. Consider rejecting such taps instead.
                let len = graph.node(t).shape.num_elements().unwrap_or(0);
                (t, len)
            })
            .collect();
        let max_abs = vec![0.0; taps.len()];

        Self {
            graph,
            arena,
            sched,
            taps,
            max_abs,
        }
    }

    /// Mutable access to the arena, e.g. to write a batch's inputs before [`Calibrator::step`].
    pub fn arena_mut(&mut self) -> &mut Arena {
        &mut self.arena
    }

    /// Shared access to the arena.
    pub fn arena(&self) -> &Arena {
        &self.arena
    }

    /// Executes the compiled schedule once over the arena's current contents,
    /// then folds every tap's values into its running max |x|.
    ///
    /// # Panics
    ///
    /// If a tap's byte range `[offset, offset + len * 4)` lies outside the arena buffer.
    pub fn step(&mut self) {
        const F32_BYTES: usize = std::mem::size_of::<f32>();

        execute_thunks(&self.sched, self.arena.raw_buf_mut());

        let buf = self.arena.raw_buf();
        for (&(tap, len), max) in self.taps.iter().zip(self.max_abs.iter_mut()) {
            let off = self.arena.byte_offset(tap);
            let end = off + len * F32_BYTES;
            // Decode from a checked byte slice rather than casting to
            // `*const f32`. Nothing here guarantees `off` is 4-byte aligned,
            // which would make a direct f32 dereference undefined behaviour.
            // The checked range also stops a bad offset from reading past the
            // buffer.
            let bytes = buf.get(off..end).unwrap_or_else(|| {
                panic!(
                    "Calibrator: tap {tap} bytes [{off}, {end}) exceed arena ({} bytes)",
                    buf.len()
                )
            });
            // NOTE(review): values are read as native-endian f32 regardless of
            // the node's dtype. Confirm callers only tap F32 nodes.
            *max = bytes
                .chunks_exact(F32_BYTES)
                .map(|c| f32::from_ne_bytes([c[0], c[1], c[2], c[3]]).abs())
                // `v > acc` is false for NaN, so NaN samples never replace the maximum.
                .fold(*max, |acc, v| if v > acc { v } else { acc });
        }
    }

    /// Running max |x| per tap, in the order the taps were passed to [`Calibrator::new`].
    pub fn max_abs(&self) -> &[f32] {
        &self.max_abs
    }

    /// Per-tap scale `max_abs / 127`, floored at `1e-6` so it is never zero.
    ///
    /// Presumably intended for symmetric int8 quantization (127 = i8::MAX); confirm with callers.
    pub fn scales(&self) -> Vec<f32> {
        self.max_abs.iter().map(|m| (m / 127.0).max(1e-6)).collect()
    }

    /// The graph this calibrator was built from.
    pub fn graph(&self) -> &Graph {
        self.graph
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::op::*;
    use rlx_ir::*;

    /// Feeds three batches through `y = relu(x)` and checks that the tracked
    /// maxima (and derived scales) reflect the largest |value| seen across all
    /// batches, not just the most recent one.
    #[test]
    fn calibrator_tracks_max_abs_across_batches() {
        let f = DType::F32;
        let mut g = Graph::new("calib_demo");
        let x = g.input("x", Shape::new(&[4], f));
        let y = g.activation(Activation::Relu, x, Shape::new(&[4], f));
        g.set_outputs(vec![x, y]);
        let mut cal = Calibrator::new(&g, vec![x, y]);

        // Batch 1: |x| max 3, relu(x) max 1.
        write_into(cal.arena_mut(), x, &[-3.0, 1.0, -2.0, 0.5]);
        cal.step();
        // Batch 2: |x| max 7, relu(x) is all zeros, so y must keep its old max.
        write_into(cal.arena_mut(), x, &[-7.0, 0.0, -7.0, -2.0]);
        cal.step();
        // Batch 3: both x and relu(x) peak at 10.
        write_into(cal.arena_mut(), x, &[10.0, 0.0, 0.0, 5.0]);
        cal.step();

        let mx = cal.max_abs();
        assert!((mx[0] - 10.0).abs() < 1e-6, "x max_abs: {}", mx[0]);
        assert!((mx[1] - 10.0).abs() < 1e-6, "y max_abs: {}", mx[1]);
        let s = cal.scales();
        assert!((s[0] - 10.0 / 127.0).abs() < 1e-6);
        assert!((s[1] - 10.0 / 127.0).abs() < 1e-6);
    }

    /// Copies `data` as native-endian f32 into the arena slot of `id`.
    ///
    /// Writes through a checked byte slice instead of a `*mut f32` cast, so an
    /// unaligned slot offset cannot trigger undefined behaviour. Panics if the
    /// data would run past the end of the arena buffer.
    fn write_into(arena: &mut Arena, id: NodeId, data: &[f32]) {
        const F32_BYTES: usize = std::mem::size_of::<f32>();
        let off = arena.byte_offset(id);
        let end = off + data.len() * F32_BYTES;
        let dst = arena
            .raw_buf_mut()
            .get_mut(off..end)
            .expect("write_into: data runs past the end of the arena buffer");
        for (chunk, v) in dst.chunks_exact_mut(F32_BYTES).zip(data) {
            chunk.copy_from_slice(&v.to_ne_bytes());
        }
    }
}