use rlx_ir::{Graph, Node, NodeId, Op, Shape};
use std::collections::HashMap;
/// Quantization parameters attached to a single node.
///
/// Entries built through [`CalibrationEntry::per_tensor`] or
/// [`CalibrationEntry::per_channel`] always carry exactly one zero point per
/// scale, and every zero point is `0`.
#[derive(Debug, Clone)]
pub struct CalibrationEntry {
    /// Axis passed to `per_channel`; `None` for entries built by `per_tensor`.
    pub axis: Option<usize>,
    /// Scale values: a single element for `per_tensor`, the caller's vector
    /// for `per_channel`.
    pub scales: Vec<f32>,
    /// Zero points, index-aligned with `scales` when built by the constructors.
    pub zero_points: Vec<i32>,
}

impl CalibrationEntry {
    /// Builds an entry with no axis, the single `scale`, and one zero point of `0`.
    pub fn per_tensor(scale: f32) -> Self {
        let scales = vec![scale];
        let zero_points = vec![0];
        Self {
            axis: None,
            scales,
            zero_points,
        }
    }

    /// Builds an entry for `axis` that takes ownership of `scales` and pairs
    /// every scale with a zero point of `0`.
    ///
    /// An empty `scales` vector yields an empty `zero_points` vector.
    pub fn per_channel(axis: usize, scales: Vec<f32>) -> Self {
        // Size the zero points before `scales` is moved into the struct.
        let zero_points = vec![0; scales.len()];
        Self {
            axis: Some(axis),
            scales,
            zero_points,
        }
    }
}
/// Calibration data keyed by [`NodeId`]: at most one [`CalibrationEntry`] per node.
pub type CalibrationRecord = HashMap<NodeId, CalibrationEntry>;
/// Rebuilds `graph` with a Quantize → Dequantize pair after every node that
/// has an entry in `record`.
///
/// Each node is copied into a fresh graph carrying the original graph's name.
/// When a copied node has a calibration entry, a Quantize node consuming the
/// copy and a Dequantize node consuming that Quantize are appended. Later
/// consumers of the original node, and graph outputs that referenced it, are
/// rewired to the Dequantize node. Nodes without an entry are copied as-is.
///
/// # Panics
///
/// Panics if a node input or a graph output refers to a node id that has not
/// been copied yet, i.e. when `graph.nodes()` does not list producers before
/// their consumers.
pub fn insert_q_dq(graph: Graph, record: &CalibrationRecord) -> Graph {
    /// Maps an id from the source graph to its replacement in the new graph,
    /// preferring the Dequantize node when the source node was tapped.
    fn resolve(
        old: NodeId,
        copied: &HashMap<NodeId, NodeId>,
        dequantized: &HashMap<NodeId, NodeId>,
    ) -> NodeId {
        match dequantized.get(&old) {
            Some(&dq) => dq,
            // Every tapped node is also recorded in `copied`, so this lookup
            // fails only for ids that were never visited.
            None => copied[&old],
        }
    }

    let mut rewritten = Graph::new(&graph.name);
    // Source id -> id of the plain copy in `rewritten`.
    let mut copied: HashMap<NodeId, NodeId> = HashMap::new();
    // Source id -> id of the Dequantize node appended after its copy.
    let mut dequantized: HashMap<NodeId, NodeId> = HashMap::new();

    for node in graph.nodes() {
        let rewired: Vec<NodeId> = node
            .inputs
            .iter()
            .map(|&old| resolve(old, &copied, &dequantized))
            .collect();
        let copy_id = rewritten.add_node(node.op.clone(), rewired, node.shape.clone());
        copied.insert(node.id, copy_id);

        if let Some(entry) = record.get(&node.id) {
            let q_id = insert_quantize(copy_id, node, entry, &mut rewritten);
            let dq_id = insert_dequantize(q_id, node, entry, &mut rewritten);
            dequantized.insert(node.id, dq_id);
        }
    }

    let rewired_outputs: Vec<NodeId> = graph
        .outputs
        .iter()
        .map(|&old| resolve(old, &copied, &dequantized))
        .collect();
    rewritten.set_outputs(rewired_outputs);
    rewritten
}
fn insert_quantize(
src: NodeId,
src_node: &Node,
entry: &CalibrationEntry,
out: &mut Graph,
) -> NodeId {
let q_shape: Shape = src_node.shape.clone().with_dtype(rlx_ir::DType::I8);
out.add_node(
Op::Quantize {
axis: entry.axis,
scales: entry.scales.clone(),
zero_points: entry.zero_points.clone(),
},
vec![src],
q_shape,
)
}
fn insert_dequantize(
q: NodeId,
src_node: &Node,
entry: &CalibrationEntry,
out: &mut Graph,
) -> NodeId {
let dq_shape: Shape = src_node.shape.clone().with_dtype(rlx_ir::DType::F32);
out.add_node(
Op::Dequantize {
axis: entry.axis,
scales: entry.scales.clone(),
zero_points: entry.zero_points.clone(),
},
vec![q],
dq_shape,
)
}
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::op::*;
    use rlx_ir::*;

    /// Tapping `y` must add a Quantize and a Dequantize node, and both inputs
    /// of the downstream Add (which reads `y` twice) must come from Dequantize.
    #[test]
    fn inserts_q_dq_pair_after_tap() {
        let dtype = DType::F32;
        let mut graph = Graph::new("ptq_demo");
        let x = graph.input("x", Shape::new(&[4, 8], dtype));
        let y = graph.activation(Activation::Relu, x, Shape::new(&[4, 8], dtype));
        let z = graph.binary(BinaryOp::Add, y, y, Shape::new(&[4, 8], dtype));
        graph.set_outputs(vec![z]);

        let record = CalibrationRecord::from([(y, CalibrationEntry::per_tensor(0.05))]);
        let rewritten = insert_q_dq(graph, &record);

        let has_quantize = rewritten
            .nodes()
            .iter()
            .any(|n| matches!(n.op, Op::Quantize { .. }));
        assert!(has_quantize);

        let has_dequantize = rewritten
            .nodes()
            .iter()
            .any(|n| matches!(n.op, Op::Dequantize { .. }));
        assert!(has_dequantize);

        let add = rewritten
            .nodes()
            .iter()
            .find(|n| matches!(n.op, Op::Binary(BinaryOp::Add)))
            .expect("add node");
        for in_id in add.inputs.iter().copied() {
            let in_op = &rewritten.node(in_id).op;
            assert!(
                matches!(in_op, Op::Dequantize { .. }),
                "Add input should be Dequantize, got {in_op:?}"
            );
        }
    }

    /// With an empty record the pass must not change the node count.
    #[test]
    fn untagged_nodes_pass_through_unchanged() {
        let dtype = DType::F32;
        let mut graph = Graph::new("no_taps");
        let x = graph.input("x", Shape::new(&[4], dtype));
        let y = graph.activation(Activation::Relu, x, Shape::new(&[4], dtype));
        graph.set_outputs(vec![y]);

        let expected_len = graph.len();
        let no_taps = CalibrationRecord::new();
        let rewritten = insert_q_dq(graph, &no_taps);
        assert_eq!(rewritten.len(), expected_len);
    }
}