use super::types::{
CalibrationData, QuantizationAnnotation, QuantizationParams, QuantizationScheme,
};
use crate::{FxGraph, Node, TorshResult};
use petgraph::graph::NodeIndex;
use std::collections::HashMap;
use torsh_core::error::TorshError;
/// State shared by the quantization passes in this module: per-node
/// annotations, in-progress calibration data, and the scheme used whenever
/// quantization parameters are derived.
pub struct QuantizationContext {
/// Quantization annotation recorded for each graph node, keyed by node index.
annotations: HashMap<NodeIndex, QuantizationAnnotation>,
/// Calibration data for nodes whose calibration has been started and not yet
/// finalized (finalizing removes the entry).
calibration_data: HashMap<NodeIndex, CalibrationData>,
/// Scheme passed along when computing parameters from calibration data and
/// when building the annotations recorded by `prepare_qat`.
global_scheme: QuantizationScheme,
}
impl QuantizationContext {
    /// Creates an empty context that uses `scheme` wherever quantization
    /// parameters are derived.
    pub fn new(scheme: QuantizationScheme) -> Self {
        Self {
            global_scheme: scheme,
            annotations: HashMap::new(),
            calibration_data: HashMap::new(),
        }
    }

    /// Records `annotation` for `node`, replacing any annotation previously
    /// stored for the same node.
    pub fn annotate_node(&mut self, node: NodeIndex, annotation: QuantizationAnnotation) {
        self.annotations.insert(node, annotation);
    }

    /// Test-only view of every recorded annotation.
    #[cfg(test)]
    pub fn annotations(&self) -> &HashMap<NodeIndex, QuantizationAnnotation> {
        &self.annotations
    }

    /// Returns the annotation recorded for `node`, or `None` if there is none.
    pub fn get_annotation(&self, node: NodeIndex) -> Option<&QuantizationAnnotation> {
        self.annotations.get(&node)
    }

    /// Starts calibration for `node` with a fresh `CalibrationData`.
    ///
    /// Calling this again for the same node replaces whatever data had been
    /// collected so far.
    pub fn start_calibration(&mut self, node: NodeIndex) {
        self.calibration_data.insert(node, CalibrationData::new());
    }

    /// Feeds `values` into the calibration data of `node`.
    ///
    /// # Errors
    ///
    /// Returns `TorshError::InvalidArgument` when calibration has not been
    /// started for `node`.
    pub fn update_calibration(&mut self, node: NodeIndex, values: &[f32]) -> TorshResult<()> {
        let data = self.calibration_data.get_mut(&node).ok_or_else(|| {
            TorshError::InvalidArgument(format!(
                "Calibration not started for node {:?}",
                node
            ))
        })?;
        data.update(values);
        Ok(())
    }

    /// Ends calibration for `node` and returns the parameters computed from
    /// the collected data under the context's scheme.
    ///
    /// The calibration entry is removed, so a second call for the same node
    /// fails unless calibration is started again.
    ///
    /// # Errors
    ///
    /// Returns `TorshError::InvalidArgument` when no calibration data exists
    /// for `node`.
    pub fn finalize_calibration(&mut self, node: NodeIndex) -> TorshResult<QuantizationParams> {
        self.calibration_data
            .remove(&node)
            .map(|data| data.compute_params(self.global_scheme))
            .ok_or_else(|| {
                TorshError::InvalidArgument(format!(
                    "No calibration data for node {:?}",
                    node
                ))
            })
    }

    /// Prepares `graph` for quantization-aware training by processing every
    /// call node whose operation is quantizable.
    ///
    /// Targets are collected first and processed afterwards, so the graph is
    /// not being iterated while the per-node step runs.
    pub fn prepare_qat(&mut self, graph: &mut FxGraph) -> TorshResult<()> {
        let mut targets = Vec::new();
        for (idx, node) in graph.nodes() {
            match node {
                Node::Call(op_name, _) if self.should_quantize_operation(op_name) => {
                    targets.push(idx);
                }
                _ => {}
            }
        }

        targets
            .into_iter()
            .try_for_each(|idx| self.insert_fake_quantize_node(graph, idx))
    }

    /// Rewrites every quantizable call node of `graph` in place so that it
    /// calls the `quantized_`-prefixed operation with the same arguments.
    ///
    /// As written, this always returns `Ok(())`.
    pub fn quantize_graph(&self, graph: &mut FxGraph) -> TorshResult<()> {
        // Gather the rewrites up front; the node weights cannot be mutated
        // while the graph is still borrowed by the scan.
        let mut rewrites = Vec::new();
        for (idx, node) in graph.nodes() {
            match node {
                Node::Call(op_name, args) if self.should_quantize_operation(op_name) => {
                    rewrites.push((idx, self.get_quantized_operation(op_name), args.clone()));
                }
                _ => {}
            }
        }

        for (idx, new_op, args) in rewrites {
            if let Some(slot) = graph.graph.node_weight_mut(idx) {
                *slot = Node::Call(new_op, args);
            }
        }
        Ok(())
    }

    /// Reports whether `op_name` is one of the operations this context
    /// quantizes.
    fn should_quantize_operation(&self, op_name: &str) -> bool {
        const QUANTIZABLE_OPS: &[&str] = &["linear", "conv2d", "matmul", "add", "mul"];
        QUANTIZABLE_OPS.contains(&op_name)
    }

    /// Maps an operation name to its quantized counterpart, which is the
    /// original name prefixed with `quantized_`.
    fn get_quantized_operation(&self, op_name: &str) -> String {
        format!("quantized_{op_name}")
    }

    /// Records a default annotation for `target_idx`: one input parameter set
    /// and one output parameter set, both built with
    /// `QuantizationParams::symmetric(global_scheme, 1.0)`, and no calibration
    /// data.
    ///
    /// Despite the name, `_graph` is currently unused and no node is inserted;
    /// only the annotation map changes.
    fn insert_fake_quantize_node(
        &mut self,
        _graph: &mut FxGraph,
        target_idx: NodeIndex,
    ) -> TorshResult<()> {
        let scheme = self.global_scheme;
        let default_params = || QuantizationParams::symmetric(scheme, 1.0);

        let annotation = QuantizationAnnotation {
            input_params: vec![Some(default_params())],
            output_params: Some(default_params()),
            calibration_data: None,
        };
        self.annotate_node(target_idx, annotation);
        Ok(())
    }
}