use crate::error::{LinalgError, LinalgResult};
use crate::quantization::{
dequantize_matrix, get_quantizedmatrix_2d_i8, QuantizationMethod, QuantizationParams,
QuantizedData2D, QuantizedMatrix,
};
use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
use std::fmt::Debug;
/// Multiplies a chain of quantized matrices, returning `Q_0 · Q_1 · … · Q_{n-1}` in `f32`.
///
/// # Arguments
/// * `matrices` - at least two quantized matrices; the columns of `matrices[i]` must equal
///   the rows of `matrices[i + 1]`.
/// * `params` - quantization parameters, one per matrix, in the same order as `matrices`.
///
/// # Errors
/// Returns `LinalgError::ShapeError` if fewer than two matrices are supplied, if
/// `params.len() != matrices.len()`, or if two adjacent matrices have incompatible shapes.
///
/// When every matrix stores `QuantizedData2D::Int8` data and every parameter set uses
/// `QuantizationMethod::Symmetric`, the chain is evaluated directly on the integer values with
/// one combined scale. In every other case each matrix is dequantized and multiplied in `f32`.
#[allow(dead_code)]
pub fn fused_quantized_matmul_chain(
    matrices: &[&QuantizedMatrix],
    params: &[&QuantizationParams],
) -> LinalgResult<Array2<f32>> {
    if matrices.len() < 2 {
        return Err(LinalgError::ShapeError(
            "At least two matrices are required for matmul chain".to_string(),
        ));
    }
    if matrices.len() != params.len() {
        return Err(LinalgError::ShapeError(
            "Number of matrices must match number of quantization parameters".to_string(),
        ));
    }
    for (i, pair) in matrices.windows(2).enumerate() {
        let (lhs, rhs) = (pair[0], pair[1]);
        if lhs.shape.1 != rhs.shape.0 {
            return Err(LinalgError::ShapeError(format!(
                "Matrix dimensions mismatch at position {}: ({}, {}) * ({}, {})",
                i, lhs.shape.0, lhs.shape.1, rhs.shape.0, rhs.shape.1
            )));
        }
    }

    let all_int8 = matrices
        .iter()
        .all(|m| matches!(m.data, QuantizedData2D::Int8(_)));
    // Only plain `Symmetric` takes the integer fast path, matching the symmetric check in
    // `fused_quantized_matvec_sequence_int8`. `Int4` is deliberately excluded: the fast path
    // indexes the Int8 storage element-by-element as `[[row, col]]`, which is only valid if each
    // stored byte is one value. NOTE(review): Int4 data is presumably packed (two nibbles per
    // byte) — confirm in `quantize_matrix`. The dequantizing path below is correct for any method.
    let all_symmetric = params
        .iter()
        .all(|p| p.method == QuantizationMethod::Symmetric);
    if all_int8 && all_symmetric {
        return fused_quantized_matmul_chain_int8_symmetric(matrices, params);
    }

    // Dequantize lazily and fold left-to-right, so the first operand never needs to be cloned.
    let mut dequantized = matrices
        .iter()
        .zip(params.iter())
        .map(|(matrix, param)| dequantize_matrix(matrix, param));
    let first = dequantized
        .next()
        .expect("at least two matrices were validated above");
    Ok(dequantized.fold(first, |acc, next| acc.dot(&next)))
}
/// Multiplies a chain of Int8-backed, symmetrically quantized matrices.
///
/// For each output row `i`, row `i` of `Q_0 · Q_1 · … · Q_{n-2}` is built once from the raw
/// `i8` values. That row is then dotted with every column of the last matrix. The product of
/// all scales is applied once per output element; zero points are not read (symmetric scheme).
///
/// Callers must already have validated that there are at least two matrices, that adjacent
/// shapes chain, and that every matrix stores `QuantizedData2D::Int8` data.
///
/// # Panics
/// Panics if fewer than two matrices are passed or if a matrix does not hold Int8 data.
#[allow(dead_code)]
fn fused_quantized_matmul_chain_int8_symmetric(
    matrices: &[&QuantizedMatrix],
    params: &[&QuantizationParams],
) -> LinalgResult<Array2<f32>> {
    let int8_matrices: Vec<&Array2<i8>> = matrices
        .iter()
        .map(|m| {
            get_quantizedmatrix_2d_i8(m)
                .expect("caller guarantees every matrix holds Int8 data")
        })
        .collect();
    let (first, rest) = int8_matrices
        .split_first()
        .expect("caller guarantees at least two matrices");
    let (last, middle) = rest
        .split_last()
        .expect("caller guarantees at least two matrices");
    // Shapes of the matrices strictly between the first and the last one.
    let middle_shapes = &matrices[1..matrices.len() - 1];

    // Combine the scales in f64 so long chains of small scales do not underflow f32.
    let fused_scale: f64 = params.iter().map(|p| f64::from(p.scale)).product();

    let rows = matrices[0].shape.0;
    let first_cols = matrices[0].shape.1;
    let cols = matrices[matrices.len() - 1].shape.1;
    let mut result = Array2::zeros((rows, cols));

    for i in 0..rows {
        // The running row is identical for every output column j, so it is built once per row.
        // It is kept in f64: integer-valued sums stay exact below 2^53, and long chains or wide
        // inner dimensions cannot overflow the way an i32 accumulator would.
        let mut row: Vec<f64> = (0..first_cols).map(|k| f64::from(first[[i, k]])).collect();
        for (mat, qm) in middle.iter().zip(middle_shapes.iter()) {
            let inner = qm.shape.1;
            let mut next = vec![0.0f64; inner];
            for (k, &v) in row.iter().enumerate() {
                for (l, out) in next.iter_mut().enumerate() {
                    *out += v * f64::from(mat[[k, l]]);
                }
            }
            row = next;
        }
        for j in 0..cols {
            let sum: f64 = row
                .iter()
                .enumerate()
                .map(|(k, &v)| v * f64::from(last[[k, j]]))
                .sum();
            result[[i, j]] = (sum * fused_scale) as f32;
        }
    }
    Ok(result)
}
#[allow(dead_code)]
pub fn fused_quantized_matvec_sequence<F>(
matrices: &[&QuantizedMatrix],
matrix_params: &[&QuantizationParams],
vector: &ArrayView1<F>,
output_quantize: bool,
) -> LinalgResult<Array1<F>>
where
F: scirs2_core::numeric::Float
+ Debug
+ scirs2_core::numeric::AsPrimitive<f32>
+ scirs2_core::numeric::FromPrimitive,
f32: scirs2_core::numeric::AsPrimitive<F>,
{
if matrices.is_empty() {
return Err(LinalgError::ShapeError(
"At least one matrix is required for matvec sequence".to_string(),
));
}
if matrices.len() != matrix_params.len() {
return Err(LinalgError::ShapeError(
"Number of matrices must match number of quantization parameters".to_string(),
));
}
let vector_len = vector.len();
if matrices.last().expect("Operation failed").shape.1 != vector_len {
return Err(LinalgError::ShapeError(format!(
"Last matrix columns ({}) must match vector length ({})",
matrices.last().expect("Operation failed").shape.1,
vector_len
)));
}
for i in 0..matrices.len() - 1 {
if matrices[i].shape.1 != matrices[i + 1].shape.0 {
return Err(LinalgError::ShapeError(format!(
"Matrix dimensions mismatch at position {}: ({}, {}) * ({}, {})",
i,
matrices[i].shape.0,
matrices[i].shape.1,
matrices[i + 1].shape.0,
matrices[i + 1].shape.1
)));
}
}
let all_int8 = matrices
.iter()
.all(|m| matches!(m.data, QuantizedData2D::Int8(_)));
if all_int8 {
let vector_f32 = vector.mapv(|x| x.as_());
let vector_f32_view = vector_f32.view();
let result_f32 = if matrices.len() == 1 {
use crate::quantization::simd::simd_quantized_matvec;
simd_quantized_matvec(matrices[0], matrix_params[0], &vector_f32_view)?
} else {
fused_quantized_matvec_sequence_int8(matrices, matrix_params, &vector_f32_view)?
};
if output_quantize {
Ok(result_f32.mapv(|x| {
scirs2_core::numeric::FromPrimitive::from_f32(x).expect("Operation failed")
}))
} else {
Ok(result_f32.mapv(|x| {
scirs2_core::numeric::FromPrimitive::from_f32(x).expect("Operation failed")
}))
}
} else {
let mut dequantized_matrices = Vec::with_capacity(matrices.len());
for (matrix, param) in matrices.iter().zip(matrix_params.iter()) {
dequantized_matrices.push(dequantize_matrix(matrix, param));
}
let vector_f32 = vector.mapv(|x| x.as_());
let mut result_f32 = vector_f32.insert_axis(scirs2_core::ndarray::Axis(1));
for mat in dequantized_matrices.iter().rev() {
result_f32 = mat.dot(&result_f32);
}
let result_1d_f32 = result_f32.remove_axis(scirs2_core::ndarray::Axis(1));
let result_f = result_1d_f32
.mapv(|x| scirs2_core::numeric::FromPrimitive::from_f32(x).expect("Operation failed"));
Ok(result_f)
}
}
/// Evaluates `Q_0 · Q_1 · … · Q_{n-1} · x` for quantized matrices that store Int8 data.
///
/// If every parameter set uses `QuantizationMethod::Symmetric`, the raw `i8` values are applied
/// to the vector right-to-left, and the product of all scales is applied once at the end. Zero
/// points are not read on this path — NOTE(review): this assumes symmetric parameters always
/// have a zero point of 0; confirm. For any other method, every matrix is dequantized and the
/// chain is evaluated in `f32`.
///
/// Callers must already have validated that `matrices` is non-empty, matches `params` in
/// length, and has adjacent shapes that chain. The last matrix's columns must equal `vector.len()`.
///
/// # Panics
/// On the symmetric path, panics if a matrix does not hold Int8 data.
#[allow(dead_code)]
fn fused_quantized_matvec_sequence_int8(
    matrices: &[&QuantizedMatrix],
    params: &[&QuantizationParams],
    vector: &ArrayView1<f32>,
) -> LinalgResult<Array1<f32>> {
    let symmetric = params
        .iter()
        .all(|p| p.method == QuantizationMethod::Symmetric);

    if !symmetric {
        let dequantized: Vec<Array2<f32>> = matrices
            .iter()
            .zip(params.iter())
            .map(|(matrix, param)| dequantize_matrix(matrix, param))
            .collect();
        let column = dequantized.iter().rev().fold(
            vector.to_owned().insert_axis(scirs2_core::ndarray::Axis(1)),
            |acc, mat| mat.dot(&acc),
        );
        return Ok(column.remove_axis(scirs2_core::ndarray::Axis(1)));
    }

    let int8_matrices: Vec<&Array2<i8>> = matrices
        .iter()
        .map(|m| {
            get_quantizedmatrix_2d_i8(m)
                .expect("caller guarantees every matrix holds Int8 data")
        })
        .collect();
    let fused_scale: f64 = params.iter().map(|p| f64::from(p.scale)).product();

    // Right-to-left evaluation keeps every step a matrix-vector product, instead of building a
    // full row of the chained matrix product per output element. Accumulating in f64 cannot
    // overflow, unlike an i32 accumulator of chained i8 products.
    let mut acc: Vec<f64> = vector.iter().map(|&x| f64::from(x)).collect();
    for (mat, qm) in int8_matrices.iter().zip(matrices.iter()).rev() {
        let (rows, cols) = qm.shape;
        let next: Vec<f64> = (0..rows)
            .map(|r| {
                (0..cols)
                    .map(|c| f64::from(mat[[r, c]]) * acc[c])
                    .sum::<f64>()
            })
            .collect();
        acc = next;
    }

    Ok(Array1::from_vec(
        acc.iter().map(|&x| (x * fused_scale) as f32).collect(),
    ))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::quantization::{quantize_matrix, QuantizationMethod};
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    /// A 3-matrix symmetric Int8 chain should approximate the exact f32 product.
    #[test]
    fn test_fused_matmul_chain() {
        let a = array![[1.0f32, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let b = array![[7.0f32, 8.0], [9.0, 10.0], [11.0, 12.0]];
        let c = array![[13.0f32, 14.0, 15.0], [16.0, 17.0, 18.0]];
        let (qa, qa_params) = quantize_matrix(&a.view(), 8, QuantizationMethod::Symmetric);
        let (qb, qb_params) = quantize_matrix(&b.view(), 8, QuantizationMethod::Symmetric);
        let (qc, qc_params) = quantize_matrix(&c.view(), 8, QuantizationMethod::Symmetric);
        let ab = a.dot(&b);
        let expected = ab.dot(&c);
        let matrices = [&qa, &qb, &qc];
        let params = [&qa_params, &qb_params, &qc_params];
        let result = fused_quantized_matmul_chain(&matrices, &params).expect("Operation failed");
        assert_eq!(result.shape(), expected.shape());
        for ((i, j), &val) in result.indexed_iter() {
            assert_relative_eq!(val, expected[[i, j]], epsilon = 12.0);
        }
    }

    /// Too few matrices, mismatched parameter counts, and non-chaining shapes are all rejected.
    #[test]
    fn test_fused_matmul_chain_rejects_invalid_input() {
        let a = array![[1.0f32, 2.0], [3.0, 4.0]];
        let b = array![[1.0f32, 2.0, 3.0]];
        let (qa, qa_params) = quantize_matrix(&a.view(), 8, QuantizationMethod::Symmetric);
        let (qb, qb_params) = quantize_matrix(&b.view(), 8, QuantizationMethod::Symmetric);
        assert!(fused_quantized_matmul_chain(&[&qa], &[&qa_params]).is_err());
        assert!(fused_quantized_matmul_chain(&[&qa, &qa], &[&qa_params]).is_err());
        // a is 2x2 and b is 1x3: inner dimensions 2 and 1 do not chain.
        assert!(fused_quantized_matmul_chain(&[&qa, &qb], &[&qa_params, &qb_params]).is_err());
    }

    /// A 2-matrix symmetric Int8 matvec chain should approximate the exact f32 result.
    #[test]
    fn test_fused_matvec_sequence() {
        let a = array![[1.0f32, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let b = array![[7.0f32, 8.0], [9.0, 10.0], [11.0, 12.0]];
        let x = array![13.0f32, 14.0];
        let (qa, qa_params) = quantize_matrix(&a.view(), 8, QuantizationMethod::Symmetric);
        let (qb, qb_params) = quantize_matrix(&b.view(), 8, QuantizationMethod::Symmetric);
        let bx = b.dot(&x);
        let expected = a.dot(&bx);
        let matrices = [&qa, &qb];
        let params = [&qa_params, &qb_params];
        let result = fused_quantized_matvec_sequence(&matrices, &params, &x.view(), false)
            .expect("Operation failed");
        assert_eq!(result.len(), expected.len());
        for (i, &val) in result.iter().enumerate() {
            assert_relative_eq!(val, expected[i], epsilon = 5.0);
        }
    }

    /// A 3-matrix chain exercises the middle-matrix step of the Int8 matvec kernel.
    #[test]
    fn test_fused_matvec_sequence_three_matrices() {
        let a = array![[1.0f32, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let b = array![[7.0f32, 8.0], [9.0, 10.0], [11.0, 12.0]];
        let c = array![[1.0f32, 0.5], [0.25, 2.0]];
        let x = array![1.0f32, 2.0];
        let (qa, qa_params) = quantize_matrix(&a.view(), 8, QuantizationMethod::Symmetric);
        let (qb, qb_params) = quantize_matrix(&b.view(), 8, QuantizationMethod::Symmetric);
        let (qc, qc_params) = quantize_matrix(&c.view(), 8, QuantizationMethod::Symmetric);
        let expected = a.dot(&b.dot(&c.dot(&x)));
        let matrices = [&qa, &qb, &qc];
        let params = [&qa_params, &qb_params, &qc_params];
        let result = fused_quantized_matvec_sequence(&matrices, &params, &x.view(), false)
            .expect("Operation failed");
        assert_eq!(result.len(), expected.len());
        for (&got, &want) in result.iter().zip(expected.iter()) {
            assert_relative_eq!(got, want, max_relative = 0.05);
        }
    }
}