use scirs2_core::ndarray::Array2;
use scirs2_core::random::{rng, Rng, RngExt};
use scirs2_core::random::{Distribution, Normal};
use scirs2_linalg::quantization::calibration::{
calibrate_matrix, CalibrationConfig, CalibrationMethod,
};
use scirs2_linalg::quantization::{dequantize_matrix, quantize_matrix, quantized_matmul};
#[allow(dead_code)]
fn main() {
    println!("Quantization for Machine Learning Example");
    println!("=========================================\n");

    // Build a small synthetic "layer": a weight matrix plus a batch of
    // ReLU-style activations to multiply against it.
    println!("Creating synthetic model weights and activations...");
    let (input_dim, output_dim, batch) = (64, 32, 10);
    let weights = create_model_weights(input_dim, output_dim);
    let activations = create_activations(batch, input_dim);

    // Experiment 1: how the calibration method affects matmul accuracy.
    println!("\nComparing matrix multiplication accuracy with different calibration methods:");
    compare_matmul_accuracy(&weights, &activations, 8);

    // Experiment 2: accuracy / memory trade-off across bit widths.
    println!("\nComparing quantization bit widths for matrix multiplication:");
    compare_bit_widths_matmul(&weights, &activations);

    // Experiment 3: different bit widths for weights vs. activations.
    println!("\nDemonstrating mixed precision quantization:");
    demonstrate_mixed_precision(&weights, &activations);
}
/// Creates a synthetic weight matrix of shape `(outputsize, inputsize)` whose
/// entries are drawn i.i.d. from a normal distribution N(0, 0.1) — a typical
/// small-variance initialization for model weights.
#[allow(dead_code)]
fn create_model_weights(inputsize: usize, outputsize: usize) -> Array2<f32> {
    let mut rng = scirs2_core::random::rng();
    // Parameters (mean 0.0, std-dev 0.1) are valid, so construction cannot fail.
    let normal = Normal::new(0.0, 0.1).expect("valid normal distribution parameters");
    // from_shape_fn fills in logical (row-major) order, matching the original
    // nested index loops while avoiding per-element bounds-checked indexing.
    Array2::from_shape_fn((outputsize, inputsize), |_| normal.sample(&mut rng))
}
/// Creates a synthetic activation matrix of shape `(_batchsize, featuresize)`.
///
/// Values are drawn from N(0, 1) and passed through a ReLU (negatives become
/// zero); afterwards 5 randomly chosen entries are overwritten with large
/// "outlier" values in [2, 5) so that percentile/entropy calibration has
/// something non-trivial to clip.
#[allow(dead_code)]
fn create_activations(_batchsize: usize, featuresize: usize) -> Array2<f32> {
    let mut rng = scirs2_core::random::rng();
    // Hoisted out of the loops: the distribution is loop-invariant and was
    // previously reconstructed (and error-checked) once per element.
    let normal = Normal::new(0.0, 1.0).expect("valid normal distribution parameters");
    let mut activations = Array2::zeros((_batchsize, featuresize));
    for i in 0.._batchsize {
        for j in 0..featuresize {
            let val = normal.sample(&mut rng);
            // ReLU: keep positive draws, clamp the rest to zero.
            activations[[i, j]] = if val > 0.0 { val } else { 0.0 };
        }
    }
    // Inject a handful of large outliers at random positions.
    for _ in 0..5 {
        let i = rng.random_range(0.._batchsize);
        let j = rng.random_range(0..featuresize);
        activations[[i, j]] = rng.random_range(2.0..5.0);
    }
    activations
}
/// Compares end-to-end quantized matmul accuracy for several calibration
/// methods at a fixed bit width.
///
/// For each method it calibrates weights (symmetric range) and activations
/// (asymmetric range), quantizes both, and prints a table row with each
/// matrix's reconstruction MSE plus the MSE / relative error of the quantized
/// product against the f32 reference `activations · weightsᵀ`.
#[allow(dead_code)]
fn compare_matmul_accuracy(weights: &Array2<f32>, activations: &Array2<f32>, bits: u8) {
    // Full-precision reference: (batch × in) · (in × out) = (batch × out).
    let reference_result = activations.dot(&weights.t());
    let methods = [
        CalibrationMethod::MinMax,
        CalibrationMethod::PercentileCalibration,
        CalibrationMethod::EntropyCalibration,
        CalibrationMethod::MSEOptimization,
    ];
    // Table header and separator row.
    println!(
        "{:^25} | {:^12} | {:^12} | {:^15} | {:^15}",
        "Method", "Weight MSE", "Act MSE", "MatMul MSE", "Rel Error (%)"
    );
    println!(
        "{:-^25} | {:-^12} | {:-^12} | {:-^15} | {:-^15}",
        "", "", "", "", ""
    );
    for &method in &methods {
        // Weights are roughly zero-centered, so a symmetric range fits;
        // post-ReLU activations are non-negative, hence asymmetric.
        let config_weights = CalibrationConfig {
            method,
            symmetric: true, percentile: 0.99, num_bins: 256,
            ..Default::default()
        };
        let config_activations = CalibrationConfig {
            method,
            symmetric: false, percentile: 0.99,
            num_bins: 256,
            ..Default::default()
        };
        let weights_params =
            calibrate_matrix(&weights.view(), bits, &config_weights).expect("Operation failed");
        // NOTE(review): the params returned by quantize_matrix are discarded,
        // while dequantization below uses the *calibrated* params — presumably
        // quantize_matrix derives a compatible scale for this method; confirm,
        // otherwise the MSE figures mix two different parameter sets.
        let (quantized_weights, _) = quantize_matrix(&weights.view(), bits, weights_params.method);
        let dequantized_weights = dequantize_matrix(&quantized_weights, &weights_params);
        // Per-element reconstruction error of the weights alone.
        let weights_mse =
            (weights - &dequantized_weights).mapv(|x| x * x).sum() / weights.len() as f32;
        let activations_params = calibrate_matrix(&activations.view(), bits, &config_activations)
            .expect("Operation failed");
        let (quantized_activations, _) =
            quantize_matrix(&activations.view(), bits, activations_params.method);
        let dequantized_activations =
            dequantize_matrix(&quantized_activations, &activations_params);
        // Per-element reconstruction error of the activations alone.
        let activations_mse = (activations - &dequantized_activations)
            .mapv(|x| x * x)
            .sum()
            / activations.len() as f32;
        // NOTE(review): quantized_matmul receives (weights, activations) while
        // both the reference and the Err fallback compute
        // activations · weightsᵀ — this assumes the API handles operand
        // order/transposition internally; TODO confirm against its docs.
        let quantized_result = match quantized_matmul(
            &quantized_weights,
            &weights_params,
            &quantized_activations,
            &activations_params,
        ) {
            Ok(result) => result,
            Err(e) => {
                // Best-effort fallback: report the error and multiply the
                // dequantized operands in f32 so the table row still prints.
                println!("Error in quantized matmul: {:?}", e);
                dequantized_activations.dot(&dequantized_weights.t())
            }
        };
        // MSE of the quantized product against the f32 reference.
        let matmul_mse = (&reference_result - &quantized_result)
            .mapv(|x| x * x)
            .sum()
            / reference_result.len() as f32;
        // Sum of absolute errors normalized by the reference's L1 mass, in %.
        let rel_error = (&reference_result - &quantized_result)
            .mapv(|x| x.abs())
            .sum()
            / reference_result.mapv(|x| x.abs()).sum()
            * 100.0;
        println!(
            "{:^25} | {:^12.6} | {:^12.6} | {:^15.6} | {:^15.6}",
            format!("{:?}", method),
            weights_mse,
            activations_mse,
            matmul_mse,
            rel_error
        );
    }
}
/// Compares quantized matmul accuracy and memory savings across bit widths
/// (4, 8, 16), using entropy calibration for both operands.
///
/// Prints one table row per bit width: MSE of the quantized product against
/// the f32 reference `activations · weightsᵀ`, the relative error in percent,
/// and the memory saved relative to 32-bit float storage.
#[allow(dead_code)]
fn compare_bit_widths_matmul(weights: &Array2<f32>, activations: &Array2<f32>) {
    // Full-precision reference: (batch × in) · (in × out) = (batch × out).
    let reference_result = activations.dot(&weights.t());
    let bit_widths = [4, 8, 16];
    println!(
        "{:^10} | {:^15} | {:^15} | {:^15}",
        "Bits", "MatMul MSE", "Rel Error (%)", "Memory Savings (%)"
    );
    println!("{:-^10} | {:-^15} | {:-^15} | {:-^15}", "", "", "", "");
    // Hoisted out of the loop: neither calibration config depends on `bits`.
    // Weights use a symmetric range; post-ReLU activations (non-negative)
    // use an asymmetric one.
    let config = CalibrationConfig {
        method: CalibrationMethod::EntropyCalibration,
        symmetric: true,
        num_bins: 256,
        ..Default::default()
    };
    let config_act = CalibrationConfig {
        method: CalibrationMethod::EntropyCalibration,
        symmetric: false,
        num_bins: 256,
        ..Default::default()
    };
    for &bits in &bit_widths {
        let weights_params =
            calibrate_matrix(&weights.view(), bits, &config).expect("weight calibration failed");
        let (quantized_weights, _) = quantize_matrix(&weights.view(), bits, weights_params.method);
        let activations_params = calibrate_matrix(&activations.view(), bits, &config_act)
            .expect("activation calibration failed");
        let (quantized_activations, _) =
            quantize_matrix(&activations.view(), bits, activations_params.method);
        // Try the quantized-domain matmul; on failure fall back to multiplying
        // the dequantized matrices in f32 (still carries quantization error).
        let quantized_result = match quantized_matmul(
            &quantized_weights,
            &weights_params,
            &quantized_activations,
            &activations_params,
        ) {
            Ok(result) => result,
            Err(_) => {
                let dequantized_weights = dequantize_matrix(&quantized_weights, &weights_params);
                let dequantized_activations =
                    dequantize_matrix(&quantized_activations, &activations_params);
                dequantized_activations.dot(&dequantized_weights.t())
            }
        };
        // MSE of the quantized product against the f32 reference.
        let matmul_mse = (&reference_result - &quantized_result)
            .mapv(|x| x * x)
            .sum()
            / reference_result.len() as f32;
        // Sum of absolute errors normalized by the reference's L1 mass, in %.
        let rel_error = (&reference_result - &quantized_result)
            .mapv(|x| x.abs())
            .sum()
            / reference_result.mapv(|x| x.abs()).sum()
            * 100.0;
        // Savings relative to 32-bit float storage.
        let fp32size = 32;
        let memory_savings = (1.0 - (bits as f32 / fp32size as f32)) * 100.0;
        println!(
            "{:^10} | {:^15.6} | {:^15.6} | {:^15.1}",
            bits, matmul_mse, rel_error, memory_savings
        );
    }
}
/// Demonstrates mixed-precision quantization: different bit widths for
/// weights versus activations.
///
/// For each (weight bits, activation bits) pair it calibrates and quantizes
/// both operands, multiplies the *dequantized* matrices in f32, and prints
/// the MSE / relative error against the f32 reference together with an
/// estimated memory saving (weights weighted 75%, activations 25%).
#[allow(dead_code)]
fn demonstrate_mixed_precision(weights: &Array2<f32>, activations: &Array2<f32>) {
    // Full-precision reference: (batch × in) · (in × out) = (batch × out).
    let reference_result = activations.dot(&weights.t());
    // (weight bits, activation bits, row label)
    let configs = [
        (8, 8, "Standard (8-bit weights, 8-bit activations)"),
        (4, 8, "Mixed (4-bit weights, 8-bit activations)"),
        (8, 4, "Mixed (8-bit weights, 4-bit activations)"),
        (4, 16, "Mixed (4-bit weights, 16-bit activations)"),
    ];
    println!(
        "{:^40} | {:^15} | {:^15} | {:^15}",
        "Configuration", "MatMul MSE", "Rel Error (%)", "Memory Savings (%)"
    );
    println!("{:-^40} | {:-^15} | {:-^15} | {:-^15}", "", "", "", "");
    // Hoisted out of the loop: the calibration configs do not depend on the
    // bit widths being compared.
    let weights_config = CalibrationConfig {
        method: CalibrationMethod::EntropyCalibration,
        symmetric: true,
        num_bins: 256,
        ..Default::default()
    };
    let activations_config = CalibrationConfig {
        method: CalibrationMethod::PercentileCalibration,
        symmetric: false,
        percentile: 0.995,
        ..Default::default()
    };
    for &(weight_bits, act_bits, desc) in &configs {
        let weights_params = calibrate_matrix(&weights.view(), weight_bits, &weights_config)
            .expect("weight calibration failed");
        let (quantized_weights, _) =
            quantize_matrix(&weights.view(), weight_bits, weights_params.method);
        let activations_params =
            calibrate_matrix(&activations.view(), act_bits, &activations_config)
                .expect("activation calibration failed");
        let (quantized_activations, _) =
            quantize_matrix(&activations.view(), act_bits, activations_params.method);
        // No quantized-domain mixed matmul is attempted: both operands are
        // dequantized and multiplied in f32, which still carries the full
        // quantization error of each operand.
        let dequantized_weights = dequantize_matrix(&quantized_weights, &weights_params);
        let dequantized_activations =
            dequantize_matrix(&quantized_activations, &activations_params);
        let mixed_result = dequantized_activations.dot(&dequantized_weights.t());
        // MSE of the mixed-precision product against the f32 reference.
        let matmul_mse = (&reference_result - &mixed_result).mapv(|x| x * x).sum()
            / reference_result.len() as f32;
        // Sum of absolute errors normalized by the reference's L1 mass, in %.
        let rel_error = (&reference_result - &mixed_result).mapv(|x| x.abs()).sum()
            / reference_result.mapv(|x| x.abs()).sum()
            * 100.0;
        // Rough savings model: weights usually dominate a model's footprint,
        // so they are weighted 3:1 against activations.
        let fp32size = 32;
        let weight_savings = 1.0 - (weight_bits as f32 / fp32size as f32);
        let act_savings = 1.0 - (act_bits as f32 / fp32size as f32);
        let memory_savings = (weight_savings * 0.75 + act_savings * 0.25) * 100.0;
        println!(
            "{:^40} | {:^15.6} | {:^15.6} | {:^15.1}",
            desc, matmul_mse, rel_error, memory_savings
        );
    }
}