use scirs2_core::ndarray::{Array1, Array2};
use scirs2_linalg::quantization::simd::{simd_quantized_matmul, simd_quantized_matvec};
use scirs2_linalg::quantization::{quantize_matrix, QuantizationMethod};
use std::time::Instant;
#[allow(dead_code)]
fn main() {
println!("SIMD-Accelerated Quantized Matrix Operations Example");
println!("===================================================\n");
let m = 100;
let k = 100;
let n = 100;
println!("Generating random matrices...");
let a = generate_randommatrix(m, k);
let b = generate_randommatrix(k, n);
let v = generate_random_vector(k);
println!("\nPerforming matrix-matrix multiplication tests...");
let start = Instant::now();
let c_ref = a.dot(&b);
let ref_time = start.elapsed();
println!("Reference matmul time: {:?}", ref_time);
let (a_q, a_params) = quantize_matrix(&a.view(), 8, QuantizationMethod::Symmetric);
let (b_q, b_params) = quantize_matrix(&b.view(), 8, QuantizationMethod::Symmetric);
let start = Instant::now();
let c_q_simd =
simd_quantized_matmul(&a_q, &a_params, &b_q, &b_params).expect("Operation failed");
let simd_time = start.elapsed();
println!("SIMD quantized matmul time: {:?}", simd_time);
let max_error = (&c_ref - &c_q_simd)
.mapv(|x| x.abs())
.fold(0.0_f32, |acc, &x| acc.max(x));
let rel_error = max_error / c_ref.fold(0.0_f32, |acc, &x| acc.max(x.abs()));
println!("Max absolute error: {}", max_error);
println!("Relative error: {}", rel_error);
println!(
"Speedup: {:.2}x",
ref_time.as_secs_f64() / simd_time.as_secs_f64()
);
println!("\nPerforming matrix-vector multiplication tests...");
let start = Instant::now();
let r_ref = a.dot(&v);
let ref_time = start.elapsed();
println!("Reference matvec time: {:?}", ref_time);
let start = Instant::now();
let r_q_simd = simd_quantized_matvec(&a_q, &a_params, &v.view()).expect("Operation failed");
let simd_time = start.elapsed();
println!("SIMD quantized matvec time: {:?}", simd_time);
let max_error = (&r_ref - &r_q_simd)
.mapv(|x| x.abs())
.fold(0.0_f32, |acc, &x| acc.max(x));
let rel_error = max_error / r_ref.fold(0.0_f32, |acc, &x| acc.max(x.abs()));
println!("Max absolute error: {}", max_error);
println!("Relative error: {}", rel_error);
println!(
"Speedup: {:.2}x",
ref_time.as_secs_f64() / simd_time.as_secs_f64()
);
println!("\nBenchmarking different quantization methods for matrix multiplication:");
benchmark_quantization_methods(&a, &b);
}
#[allow(dead_code)]
fn generate_randommatrix(rows: usize, cols: usize) -> Array2<f32> {
let mut matrix = Array2::zeros((rows, cols));
for i in 0..rows {
for j in 0..cols {
matrix[[i, j]] = scirs2_core::random::random::<f32>() * 2.0 - 1.0; }
}
matrix
}
#[allow(dead_code)]
fn generate_random_vector(length: usize) -> Array1<f32> {
let mut vector = Array1::zeros(length);
for i in 0..length {
vector[i] = scirs2_core::random::random::<f32>() * 2.0 - 1.0; }
vector
}
#[allow(dead_code)]
fn benchmark_quantization_methods(a: &Array2<f32>, b: &Array2<f32>) {
let methods = [
QuantizationMethod::Symmetric,
QuantizationMethod::Affine,
QuantizationMethod::Int4,
QuantizationMethod::UInt4,
QuantizationMethod::PerChannelSymmetric,
];
let bit_widths = [8, 4];
println!(
"{:^15} | {:^10} | {:^15} | {:^15} | {:^10}",
"Method", "Bits", "Time (ms)", "Rel. Error", "Speedup"
);
println!(
"{:-^15} | {:-^10} | {:-^15} | {:-^15} | {:-^10}",
"", "", "", "", ""
);
let start = Instant::now();
let c_ref = a.dot(b);
let ref_time = start.elapsed();
println!(
"{:^15} | {:^10} | {:^15.3} | {:^15} | {:^10}",
"Reference",
"32",
ref_time.as_millis() as f64,
0.0,
1.0
);
let max_abs_val = c_ref.fold(0.0_f32, |acc, &x| acc.max(x.abs()));
for &method in &methods {
for &bits in &bit_widths {
if bits != 4
&& (method == QuantizationMethod::Int4 || method == QuantizationMethod::UInt4)
{
continue;
}
if bits != 8 && method == QuantizationMethod::PerChannelSymmetric {
continue;
}
let (a_q, a_params) = quantize_matrix(&a.view(), bits, method);
let (b_q, b_params) = quantize_matrix(&b.view(), bits, method);
let start = Instant::now();
let c_q =
simd_quantized_matmul(&a_q, &a_params, &b_q, &b_params).expect("Operation failed");
let q_time = start.elapsed();
let abs_error = (&c_ref - &c_q)
.mapv(|x| x.abs())
.fold(0.0_f32, |acc, &x| acc.max(x));
let rel_error = abs_error / max_abs_val;
let speedup = ref_time.as_secs_f64() / q_time.as_secs_f64();
println!(
"{:^15} | {:^10} | {:^15.3} | {:^15.6} | {:^10.2}",
format!("{:?}", method),
bits,
q_time.as_millis() as f64,
rel_error,
speedup
);
}
}
}