use ruvector_cnn::quantize::{
CalibrationHistogram, ComputationGraph, NodeParams, NodeType, QuantizationParams,
fuse_batchnorm_to_conv, fuse_hardswish, fuse_relu, fuse_zp_to_bias,
generate_hardswish_lut, insert_qdq_nodes,
};
use std::collections::HashMap;
/// Runs the BatchNorm, ReLU and HardSwish fusion passes in sequence over a
/// seven-node linear graph, checking the reported fusion count and the
/// remaining node count after each pass.
#[test]
fn test_complete_graph_optimization_pipeline() {
    let mut graph = ComputationGraph::new();

    // Topology: Input -> Conv2d -> BatchNorm -> ReLU -> Conv2d -> HardSwish -> Output
    let input_id = graph.add_node(NodeType::Input, NodeParams::None);
    let conv1_id = graph.add_node(
        NodeType::Conv2d,
        NodeParams::Conv2d {
            // 2 out * 2 in * 2x2 kernel = 16 values
            weights: vec![1.0; 16],
            bias: Some(vec![0.0, 0.0]),
            in_channels: 2,
            out_channels: 2,
            kernel_size: 2,
        },
    );
    let bn_id = graph.add_node(
        NodeType::BatchNorm,
        NodeParams::BatchNorm {
            gamma: vec![2.0, 3.0],
            beta: vec![0.1, 0.2],
            mean: vec![1.0, 2.0],
            var: vec![1.0, 4.0],
            eps: 1e-5,
        },
    );
    let relu_id = graph.add_node(NodeType::ReLU, NodeParams::Activation);
    let conv2_id = graph.add_node(
        NodeType::Conv2d,
        NodeParams::Conv2d {
            // 1 out * 2 in * 2x2 kernel = 8 values
            weights: vec![1.0; 8],
            bias: None,
            in_channels: 2,
            out_channels: 1,
            kernel_size: 2,
        },
    );
    let hs_id = graph.add_node(NodeType::HardSwish, NodeParams::Activation);
    let output_id = graph.add_node(NodeType::Output, NodeParams::None);

    graph.connect(input_id, conv1_id);
    graph.connect(conv1_id, bn_id);
    graph.connect(bn_id, relu_id);
    graph.connect(relu_id, conv2_id);
    graph.connect(conv2_id, hs_id);
    graph.connect(hs_id, output_id);
    assert_eq!(graph.nodes.len(), 7);

    // Each pass is expected to fuse exactly one node and shrink the graph by one.
    let bn_fused = fuse_batchnorm_to_conv(&mut graph);
    assert_eq!(bn_fused, 1);
    assert_eq!(graph.nodes.len(), 6);

    let relu_fused = fuse_relu(&mut graph);
    assert_eq!(relu_fused, 1);
    assert_eq!(graph.nodes.len(), 5);

    // After the last pass four nodes remain (7 original minus 3 fused).
    let hs_fused = fuse_hardswish(&mut graph);
    assert_eq!(hs_fused, 1);
    assert_eq!(graph.nodes.len(), 4);
}
/// Checks that `fuse_zp_to_bias` folds the input zero point into the Conv2d
/// bias for a conv whose input has quantization parameters registered.
#[test]
fn test_zero_point_fusion() {
    let mut graph = ComputationGraph::new();
    let input_id = graph.add_node(NodeType::Input, NodeParams::None);
    let conv_id = graph.add_node(
        NodeType::Conv2d,
        NodeParams::Conv2d {
            // Two weights per output channel: [1, 2] for channel 0, [3, 4] for channel 1.
            weights: vec![1.0, 2.0, 3.0, 4.0],
            bias: Some(vec![1.0, 2.0]),
            // 2 out * 2 in * 1x1 kernel = 4 values, so the declared shape
            // matches the weight vector and the expectations below.
            in_channels: 2,
            out_channels: 2,
            kernel_size: 1,
        },
    );
    graph.connect(input_id, conv_id);

    // Quantization parameters are keyed by the node feeding the convolution.
    let mut quant_params = HashMap::new();
    quant_params.insert(
        input_id,
        QuantizationParams {
            scale: 0.1,
            zero_point: 10,
            min_val: -12.8,
            max_val: 12.7,
            num_bins: 256,
        },
    );

    let fused = fuse_zp_to_bias(&mut graph, &quant_params);
    assert_eq!(fused, 1);

    let conv_node = graph.get_node(conv_id).unwrap();
    if let NodeParams::Conv2d { bias, .. } = &conv_node.params {
        let bias = bias.as_ref().unwrap();
        // Expected values equal bias - zero_point * sum(channel weights):
        //   channel 0: 1 - 10 * (1 + 2) = -29
        //   channel 1: 2 - 10 * (3 + 4) = -68
        assert!((bias[0] - (-29.0)).abs() < 0.01);
        assert!((bias[1] - (-68.0)).abs() < 0.01);
    } else {
        panic!("Expected Conv2d params");
    }
}
/// Registers quantization parameters for a single Conv2d and verifies that
/// `insert_qdq_nodes` adds two nodes: a Quantize node as the conv's only
/// input and a Dequantize node as its only output.
#[test]
fn test_quantize_dequantize_insertion() {
    let mut g = ComputationGraph::new();

    // Input -> Conv2d -> Output
    let source = g.add_node(NodeType::Input, NodeParams::None);
    let conv_params = NodeParams::Conv2d {
        weights: vec![1.0; 4],
        bias: None,
        in_channels: 1,
        out_channels: 1,
        kernel_size: 2,
    };
    let conv = g.add_node(NodeType::Conv2d, conv_params);
    let sink = g.add_node(NodeType::Output, NodeParams::None);
    g.connect(source, conv);
    g.connect(conv, sink);

    // Only the convolution carries quantization parameters.
    let conv_qparams = QuantizationParams {
        scale: 0.1,
        zero_point: 0,
        min_val: -12.8,
        max_val: 12.7,
        num_bins: 256,
    };
    let mut qparams_by_node = HashMap::new();
    qparams_by_node.insert(conv, conv_qparams);

    let added = insert_qdq_nodes(&mut g, &qparams_by_node);
    assert_eq!(added, 2);
    // Three original nodes plus the two inserted ones.
    assert_eq!(g.nodes.len(), 5);

    let conv_node = g.get_node(conv).unwrap();
    assert_eq!(conv_node.inputs.len(), 1);
    assert_eq!(conv_node.outputs.len(), 1);

    // The conv's sole predecessor must now be the Quantize node...
    let upstream = g.get_node(conv_node.inputs[0]).unwrap();
    assert_eq!(upstream.node_type, NodeType::Quantize);

    // ...and its sole successor the Dequantize node.
    let downstream = g.get_node(conv_node.outputs[0]).unwrap();
    assert_eq!(downstream.node_type, NodeType::Dequantize);
}
/// Spot-checks the HardSwish lookup table at table index 128, at both ends of
/// the table, and at the entry corresponding to x = 1.5.
#[test]
fn test_hardswish_lut_generation() {
    let scale = 0.1;
    let zero_point = 0;
    let lut = generate_hardswish_lut(scale, zero_point);

    // Index 128 is expected to hold 0.
    // NOTE(review): presumably the table is indexed by q + 128, so with
    // zero_point == 0 this is x = 0 — confirm against the generator.
    let idx_0 = 128;
    assert_eq!(lut[idx_0], 0);

    // The first table entry is expected to hold 0.
    let idx_neg = 0;
    assert_eq!(lut[idx_neg], 0);

    // The last table entry must dequantize to a value above 10.
    let idx_pos = 255;
    let x_pos = (lut[idx_pos] as i32 - zero_point) as f32 * scale;
    assert!(x_pos > 10.0);

    // Dequantization is x = (q - zero_point) * scale, so x = 1.5 sits at
    // q = 15 + zero_point. The zero point is therefore added, not subtracted,
    // when locating the entry (identical while zero_point == 0).
    let idx_mid = (15 + zero_point + 128) as usize;
    let x_mid = (lut[idx_mid] as i32 - zero_point) as f32 * scale;
    // HardSwish(1.5) = 1.5 * min(max(1.5 + 3, 0), 6) / 6 = 1.125; the tolerance
    // leaves room for quantization error.
    assert!((x_mid - 1.125).abs() < 0.3);
}
/// Adds 100 samples of 5.0 followed by 50 samples of -5.0 to a histogram
/// constructed with `(-10.0, 10.0, 100)`, then checks the derived quantization
/// parameters: zero point 0 and a scale within 0.01 of 10 / 127.
#[test]
fn test_calibration_histogram() {
    let mut histogram = CalibrationHistogram::new(-10.0, 10.0, 100);

    // Same insertion order as two back-to-back loops: positives first, then negatives.
    let samples = std::iter::repeat(5.0)
        .take(100)
        .chain(std::iter::repeat(-5.0).take(50));
    for sample in samples {
        histogram.add(sample);
    }

    let qparams = histogram.compute_quantization_params();
    assert_eq!(qparams.zero_point, 0);

    let scale_error = (qparams.scale - 10.0 / 127.0).abs();
    assert!(scale_error < 0.01);
}
/// Fuses a single BatchNorm into its preceding Conv2d and compares the
/// resulting weights and bias against values computed from the BatchNorm
/// parameters.
#[test]
fn test_batchnorm_fusion_preserves_semantics() {
    let mut graph = ComputationGraph::new();
    let input_id = graph.add_node(NodeType::Input, NodeParams::None);
    let conv_id = graph.add_node(
        NodeType::Conv2d,
        NodeParams::Conv2d {
            weights: vec![2.0, 3.0],
            bias: Some(vec![1.0]),
            // 1 out * 2 in * 1x1 kernel = 2 values, matching the weight vector.
            in_channels: 2,
            out_channels: 1,
            kernel_size: 1,
        },
    );
    let bn_id = graph.add_node(
        NodeType::BatchNorm,
        NodeParams::BatchNorm {
            gamma: vec![2.0],
            beta: vec![0.5],
            mean: vec![3.0],
            var: vec![4.0],
            eps: 1e-5,
        },
    );
    graph.connect(input_id, conv_id);
    graph.connect(conv_id, bn_id);

    // Expected folded parameters:
    //   scale = gamma / sqrt(var + eps)
    //   w'    = w * scale
    //   b'    = (b - mean) * scale + beta
    // NOTE(review): scale is ~1.0 for these inputs, so the weight assertions
    // would also hold for unscaled weights; the bias check (-1.5 vs. the
    // original 1.0) is what distinguishes a fused conv from an unfused one.
    let scale = 2.0 / (4.0 + 1e-5_f32).sqrt();
    let expected_w0 = 2.0 * scale;
    let expected_w1 = 3.0 * scale;
    let expected_bias = (1.0 - 3.0) * scale + 0.5;

    // Exactly one Conv2d/BatchNorm pair exists, so exactly one fusion is expected.
    let fused = fuse_batchnorm_to_conv(&mut graph);
    assert_eq!(fused, 1);

    let conv_node = graph.get_node(conv_id).unwrap();
    if let NodeParams::Conv2d { weights, bias, .. } = &conv_node.params {
        assert!((weights[0] - expected_w0).abs() < 0.01);
        assert!((weights[1] - expected_w1).abs() < 0.01);
        assert!((bias.as_ref().unwrap()[0] - expected_bias).abs() < 0.01);
    } else {
        // Fail loudly instead of passing without checking anything.
        panic!("Expected Conv2d params");
    }
}
/// Connects one Conv2d to two activation nodes and checks that the conv
/// records both of them as outputs.
#[test]
fn test_multi_output_graph() {
    let mut g = ComputationGraph::new();

    let source = g.add_node(NodeType::Input, NodeParams::None);
    let conv_params = NodeParams::Conv2d {
        weights: vec![1.0; 4],
        bias: None,
        in_channels: 1,
        out_channels: 1,
        kernel_size: 2,
    };
    let conv = g.add_node(NodeType::Conv2d, conv_params);
    let relu = g.add_node(NodeType::ReLU, NodeParams::Activation);
    let hard_swish = g.add_node(NodeType::HardSwish, NodeParams::Activation);

    g.connect(source, conv);
    // Fan the convolution out to both activations, ReLU first.
    for consumer in [relu, hard_swish] {
        g.connect(conv, consumer);
    }

    let fan_out = g.get_node(conv).unwrap().outputs.len();
    assert_eq!(fan_out, 2);
}