use scirs2_core::ndarray::{Array1, Array2};
use scirs2_linalg::quantization::{
dequantize_matrix,
quantize_matrix,
quantized_matmul,
QuantizationMethod,
QuantizedData2D,
};
#[allow(dead_code)]
fn main() {
println!("Quantization-aware Linear Algebra Example");
println!("=======================================\n");
println!("4-bit Quantization Demonstration");
println!("-------------------------------\n");
let int4_test =
Array2::from_shape_vec((2, 4), vec![1.0_f32, -2.0, 3.0, -4.0, 5.0, -6.0, 7.0, -8.0])
.expect("Operation failed");
println!("Original Matrix:");
println!("{:?}\n", int4_test);
let (int4_quantized, int4_params) =
quantize_matrix(&int4_test.view(), 4, QuantizationMethod::Int4);
println!("Storage comparison:");
println!(
"Original size: {} bytes",
int4_test.len() * std::mem::size_of::<f32>()
);
println!(
"Int4 size: {} bytes\n",
int4_quantized.data.len() * std::mem::size_of::<i8>()
);
println!("Int4 Quantization Parameters:");
println!(" Bits: {}", int4_params.bits);
println!(" Scale: {}", int4_params.scale);
println!(" Zero point: {}", int4_params.zero_point);
println!(" Min value: {}", int4_params.min_val);
println!(" Max value: {}\n", int4_params.max_val);
println!("Int4 Data (packed, 2 values per byte):");
if let QuantizedData2D::Int8(data) = &int4_quantized.data {
for row in 0..data.nrows() {
print!(" ");
for col in 0..data.ncols() {
print!("{:02x} ", data[[row, col]] as u8);
}
println!();
}
}
println!();
println!("Decoded Int4 Values:");
for row in 0..int4_test.nrows() {
print!(" ");
for col in 0..int4_test.ncols() {
print!("{:2} ", int4_quantized.get_i8(row, col));
}
println!();
}
println!();
let int4_dequantized = dequantize_matrix(&int4_quantized, &int4_params);
println!("Dequantized Matrix:");
println!("{:?}\n", int4_dequantized);
println!("Quantization Error:");
println!("{:?}\n", &int4_test - &int4_dequantized);
let uint4_test =
Array2::from_shape_vec((2, 4), vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
.expect("Operation failed");
println!("\nUInt4 Quantization Example");
println!("-------------------------\n");
println!("Original Matrix (positive values):");
println!("{:?}\n", uint4_test);
let (uint4_quantized, uint4_params) =
quantize_matrix(&uint4_test.view(), 4, QuantizationMethod::UInt4);
println!("UInt4 Quantization Parameters:");
println!(" Bits: {}", uint4_params.bits);
println!(" Scale: {}", uint4_params.scale);
println!(" Zero point: {}", uint4_params.zero_point);
println!(" Min value: {}", uint4_params.min_val);
println!(" Max value: {}\n", uint4_params.max_val);
println!("UInt4 Data (packed, 2 values per byte):");
if let QuantizedData2D::Int8(data) = &uint4_quantized.data {
for row in 0..data.nrows() {
print!(" ");
for col in 0..data.ncols() {
print!("{:02x} ", data[[row, col]] as u8);
}
println!();
}
}
println!();
println!("Decoded UInt4 Values:");
for row in 0..uint4_test.nrows() {
print!(" ");
for col in 0..uint4_test.ncols() {
print!("{:2} ", uint4_quantized.get_i8(row, col));
}
println!();
}
println!();
let uint4_dequantized = dequantize_matrix(&uint4_quantized, &uint4_params);
println!("Dequantized Matrix:");
println!("{:?}\n", uint4_dequantized);
println!("Quantization Error:");
println!("{:?}\n", &uint4_test - &uint4_dequantized);
println!("Standard 8-bit Quantization Examples");
println!("----------------------------------\n");
let a = Array2::from_shape_vec((3, 3), vec![1.2, 2.5, 3.7, 4.2, 5.0, 6.1, 7.3, 8.4, 9.5])
.expect("Operation failed");
let b = Array2::from_shape_vec((3, 2), vec![0.5, 1.5, 2.5, 3.5, 4.5, 5.5])
.expect("Operation failed");
let x = Array1::from_shape_vec(3, vec![0.1, 0.2, 0.3]).expect("Operation failed");
println!("Original Matrix A:");
println!("{:?}\n", a);
println!("Original Matrix B:");
println!("{:?}\n", b);
println!("Original Vector x:");
println!("{:?}\n", x);
println!("Basic Quantization-Dequantization");
println!("--------------------------------");
let (a_q_uniform, a_params_uniform) =
quantize_matrix(&a.view(), 8, QuantizationMethod::Uniform);
let a_dequant_uniform = dequantize_matrix(&a_q_uniform, &a_params_uniform);
println!("Uniform Quantization Parameters:");
println!(" Bits: {}", a_params_uniform.bits);
println!(" Scale: {}", a_params_uniform.scale);
println!(" Zero point: {}", a_params_uniform.zero_point);
println!(" Min value: {}", a_params_uniform.min_val);
println!(" Max value: {}", a_params_uniform.max_val);
println!();
println!("Quantized Matrix A (Uniform, 8-bit):");
if let QuantizedData2D::Int8(data) = &a_q_uniform.data {
println!("{:?}\n", data);
}
println!("Dequantized Matrix A (Uniform):");
println!("{:?}\n", a_dequant_uniform);
println!("Quantization Error (Uniform):");
println!("{:?}\n", &a - &a_dequant_uniform);
let (a_q_symmetric, a_params_symmetric) =
quantize_matrix(&a.view(), 8, QuantizationMethod::Symmetric);
let a_dequant_symmetric = dequantize_matrix(&a_q_symmetric, &a_params_symmetric);
println!("Symmetric Quantization Parameters:");
println!(" Bits: {}", a_params_symmetric.bits);
println!(" Scale: {}", a_params_symmetric.scale);
println!(" Zero point: {}", a_params_symmetric.zero_point);
println!(" Min value: {}", a_params_symmetric.min_val);
println!(" Max value: {}", a_params_symmetric.max_val);
println!();
println!("Quantized Matrix A (Symmetric, 8-bit):");
if let QuantizedData2D::Int8(data) = &a_q_symmetric.data {
println!("{:?}\n", data);
}
println!("Dequantized Matrix A (Symmetric):");
println!("{:?}\n", a_dequant_symmetric);
println!("Quantization Error (Symmetric):");
println!("{:?}\n", &a - &a_dequant_symmetric);
let (a_q_affine, a_params_affine) = quantize_matrix(&a.view(), 8, QuantizationMethod::Affine);
let a_dequant_affine = dequantize_matrix(&a_q_affine, &a_params_affine);
println!("Affine Quantization Parameters:");
println!(" Bits: {}", a_params_affine.bits);
println!(" Scale: {}", a_params_affine.scale);
println!(" Zero point: {}", a_params_affine.zero_point);
println!(" Min value: {}", a_params_affine.min_val);
println!(" Max value: {}", a_params_affine.max_val);
println!();
println!("Quantized Matrix A (Affine, 8-bit):");
if let QuantizedData2D::Int8(data) = &a_q_affine.data {
println!("{:?}\n", data);
}
println!("Dequantized Matrix A (Affine):");
println!("{:?}\n", a_dequant_affine);
println!("Quantization Error (Affine):");
println!("{:?}\n", &a - &a_dequant_affine);
println!("Quantized Matrix Operations");
println!("--------------------------");
println!("Regular Matrix Multiplication A * B:");
let c = a.dot(&b);
println!("{:?}\n", c);
let (b_q, b_params) = quantize_matrix(&b.view(), 8, QuantizationMethod::Symmetric);
println!("Quantized Matrix Multiplication (8-bit):");
let c_q = quantized_matmul(&a_q_symmetric, &a_params_symmetric, &b_q, &b_params)
.expect("Operation failed");
println!("{:?}\n", c_q);
println!("Quantization Error for Matrix Multiplication:");
println!("{:?}\n", &c - &c_q);
let rel_error = (&c - &c_q).mapv(|x| x.abs()).sum() / c.sum();
println!(
"Relative Error for Matrix Multiplication: {:.6}\n",
rel_error
);
println!("Quantization Comparison");
println!("----------------------");
let methods = [
QuantizationMethod::Uniform,
QuantizationMethod::Symmetric,
QuantizationMethod::Affine,
QuantizationMethod::PowerOfTwo,
QuantizationMethod::Int4,
QuantizationMethod::UInt4,
QuantizationMethod::Float16,
QuantizationMethod::BFloat16,
];
let bits_list = [4, 6, 8, 10, 16];
println!("Mean Squared Error (MSE) for different methods and bit widths:");
println!();
print!("{:<12}", "Bits");
for method in &methods {
let method_name = match method {
QuantizationMethod::Uniform => "Uniform",
QuantizationMethod::Symmetric => "Symmetric",
QuantizationMethod::Affine => "Affine",
QuantizationMethod::PowerOfTwo => "PowerOfTwo",
QuantizationMethod::Int4 => "Int4",
QuantizationMethod::UInt4 => "UInt4",
QuantizationMethod::Float16 => "Float16",
QuantizationMethod::BFloat16 => "BFloat16",
QuantizationMethod::PerChannelSymmetric => "PC-Symmetric",
QuantizationMethod::PerChannelAffine => "PC-Affine",
};
print!("{:<12}", method_name);
}
println!();
for &bits in &bits_list {
print!("{:<12}", bits);
for &method in &methods {
let effective_bits = match method {
QuantizationMethod::Int4 | QuantizationMethod::UInt4 => 4,
QuantizationMethod::Float16 | QuantizationMethod::BFloat16 => 16,
_ => bits,
};
print!("{:<12}", "N/A");
}
println!();
}
println!();
println!("16-bit Floating-Point Quantization");
println!("----------------------------------");
let wide_range = Array1::from_shape_vec(
8,
vec![
0.000001, 123456.0, -0.000002, -98765.0, std::f32::consts::PI, std::f32::consts::E, 0.0, 1.0, ],
)
.expect("Operation failed");
println!("Original values with wide dynamic range:");
println!("{:?}\n", wide_range);
println!("Float16 Matrix Operations");
println!("-----------------------");
let a_for_f16 = Array2::from_shape_vec((2, 3), vec![1.1, 2.2, 3.3, 4.4, 5.5, 6.6])
.expect("Operation failed");
let b_for_f16 = Array2::from_shape_vec((3, 2), vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6])
.expect("Operation failed");
let (a_f16, a_f16_params) = quantize_matrix(&a_for_f16.view(), 16, QuantizationMethod::Float16);
let (b_f16, b_f16_params) = quantize_matrix(&b_for_f16.view(), 16, QuantizationMethod::Float16);
let c_full = a_for_f16.dot(&b_for_f16);
println!("Regular matrix multiplication result (f32):");
println!("{:?}\n", c_full);
let c_f16 =
quantized_matmul(&a_f16, &a_f16_params, &b_f16, &b_f16_params).expect("Operation failed");
println!("Float16 matrix multiplication result:");
println!("{:?}\n", c_f16);
let rel_error_f16 = (&c_full - &c_f16).mapv(|x| x.abs()).sum() / c_full.sum();
println!("Float16 matmul relative error: {:.6e}\n", rel_error_f16);
let matrixsize = 100;
let largematrix = Array2::from_elem((matrixsize, matrixsize), 1.0f32);
let originalsize = matrixsize * matrixsize * std::mem::size_of::<f32>();
let (int8_large, _) = quantize_matrix(&largematrix.view(), 8, QuantizationMethod::Symmetric);
let (int4_large, _) = quantize_matrix(&largematrix.view(), 4, QuantizationMethod::Int4);
let (f16_large, _) = quantize_matrix(&largematrix.view(), 16, QuantizationMethod::Float16);
let (bf16_large, _) = quantize_matrix(&largematrix.view(), 16, QuantizationMethod::BFloat16);
let int8size = int8_large.data.len() * std::mem::size_of::<i8>();
let int4size = int4_large.data.len() * std::mem::size_of::<i8>(); let f16size = f16_large.data.len() * 2; let bf16size = bf16_large.data.len() * 2;
println!(
"Storage Efficiency Comparison ({}x{} matrix):",
matrixsize, matrixsize
);
println!(" Original f32: {} bytes (100.0%)", originalsize);
println!(
" Int8: {} bytes ({:.1}%)",
int8size,
100.0 * int8size as f32 / originalsize as f32
);
println!(
" Int4: {} bytes ({:.1}%)",
int4size,
100.0 * int4size as f32 / originalsize as f32
);
println!(
" Float16: {} bytes ({:.1}%)",
f16size,
100.0 * f16size as f32 / originalsize as f32
);
println!(
" BFloat16: {} bytes ({:.1}%)",
bf16size,
100.0 * bf16size as f32 / originalsize as f32
);
}