use half::{bf16, f16};
use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use scirs2_core::numeric::{AsPrimitive, Float, FromPrimitive};
use std::fmt::Debug;
use crate::error::{LinalgError, LinalgResult};
use super::matrix::{QuantizedData2D, QuantizedMatrix};
use super::types::{QuantizationMethod, QuantizationParams, QuantizedDataType};
use super::vector::{QuantizedData1D, QuantizedVector};
/// Quantizes a matrix into a [`QuantizedMatrix`] plus the [`QuantizationParams`]
/// needed to map the codes back to approximate `f32` values with
/// [`dequantize_matrix`].
///
/// The value range is computed over finite entries only (NaN and ±inf are
/// ignored). A constant input has its range widened to `[min, min + 1]`, and an
/// input with no finite entries (including an empty matrix) uses `[0, 1]`, so the
/// recorded `min_val`/`max_val` and the derived scale are always finite.
///
/// Supported methods:
///
/// * `Float16` / `BFloat16` — element-wise conversion; `bits` is ignored and the
///   params report 16 bits with `scale = 1`.
/// * `Uniform` — unsigned codes `round((x - min) / scale)`.
/// * `PowerOfTwo` — like `Uniform`, with `scale` rounded up to a power of two.
/// * `Symmetric` — signed codes `round(x / scale)` in `±(2^(bits-1) - 1)`.
/// * `Affine` — signed codes `round(x / scale) + zero_point`; the zero point is
///   placed so the value range maps onto `[-2^(bits-1), 2^(bits-1) - 1]`,
///   using all `2^bits` codes.
/// * `Int4` / `UInt4` — 4-bit symmetric (`[-8, 7]`) / unsigned (`[0, 15]`) codes
///   packed two per byte in row-major element order; `bits` is ignored.
///
/// All integer codes are stored as `i8` and clamped to their code range.
/// `Uniform` and `PowerOfTwo` codes are non-negative, so they use at most
/// `i8::MAX` (127) levels even when `bits == 8`.
///
/// # Panics
///
/// * If `method` is a per-channel method (use [`quantize_matrix_per_channel`]).
/// * If `bits` is outside `1..=8` for `Uniform`, `Symmetric`, `Affine` or
///   `PowerOfTwo`, because wider codes cannot be stored in an `i8`.
#[allow(dead_code)]
pub fn quantize_matrix<F>(
    matrix: &ArrayView2<F>,
    bits: u8,
    method: QuantizationMethod,
) -> (QuantizedMatrix, QuantizationParams)
where
    F: Float + Debug + AsPrimitive<f32> + FromPrimitive,
    f32: AsPrimitive<F>,
{
    // Finite (min, max) of `values`: [0, 1] when nothing is finite, and widened
    // to [min, min + 1] for constant input so the scale is never zero.
    fn finite_range(values: impl Iterator<Item = f32>) -> (f32, f32) {
        let (lo, hi) = values
            .filter(|v| v.is_finite())
            .fold((f32::INFINITY, f32::NEG_INFINITY), |(lo, hi), v| {
                (lo.min(v), hi.max(v))
            });
        if lo > hi {
            (0.0, 1.0)
        } else if hi - lo < f32::EPSILON {
            (lo, lo + 1.0)
        } else {
            (lo, hi)
        }
    }

    // Packs 4-bit codes two per byte in element order: even positions go to the
    // high nibble, odd positions to the low nibble.
    fn pack_nibbles(codes: impl Iterator<Item = i8>, bytes: &mut [i8]) {
        for (i, code) in codes.enumerate() {
            if i % 2 == 0 {
                bytes[i / 2] = code << 4;
            } else {
                bytes[i / 2] |= code & 0x0F;
            }
        }
    }

    let shape = (matrix.nrows(), matrix.ncols());
    let (min_val, max_val) = finite_range(matrix.iter().map(|&v| -> f32 { v.as_() }));

    if method == QuantizationMethod::Float16 {
        let f16_data = Array2::from_shape_fn(shape, |(r, c)| {
            let x: f32 = matrix[[r, c]].as_();
            f16::from_f32(x)
        });
        let params = QuantizationParams {
            bits: 16,
            scale: 1.0,
            zero_point: 0,
            min_val,
            max_val,
            method,
            data_type: QuantizedDataType::Float16,
            channel_scales: None,
            channel_zero_points: None,
        };
        return (QuantizedMatrix::new_f16(f16_data, shape), params);
    }
    if method == QuantizationMethod::BFloat16 {
        let bf16_data = Array2::from_shape_fn(shape, |(r, c)| {
            let x: f32 = matrix[[r, c]].as_();
            bf16::from_f32(x)
        });
        let params = QuantizationParams {
            bits: 16,
            scale: 1.0,
            zero_point: 0,
            min_val,
            max_val,
            method,
            data_type: QuantizedDataType::BFloat16,
            channel_scales: None,
            channel_zero_points: None,
        };
        return (QuantizedMatrix::new_bf16(bf16_data, shape), params);
    }

    let (effective_bits, data_type) = match method {
        QuantizationMethod::Int4 => (4, QuantizedDataType::Int4),
        QuantizationMethod::UInt4 => (4, QuantizedDataType::UInt4),
        QuantizationMethod::Uniform
        | QuantizationMethod::Symmetric
        | QuantizationMethod::Affine
        | QuantizationMethod::PowerOfTwo => (bits, QuantizedDataType::Int8),
        other => panic!(
            "quantize_matrix does not support {other:?}; per-channel methods require quantize_matrix_per_channel"
        ),
    };
    assert!(
        (1..=8).contains(&effective_bits),
        "quantize_matrix stores codes as i8, so bits must be in 1..=8, got {effective_bits}"
    );

    // Code-range sizes for an n-bit code.
    let half_span = 1i32 << (effective_bits - 1); // 2^(n-1)
    let full_levels = ((1i32 << effective_bits) - 1) as f32; // 2^n - 1
    // Unsigned codes are stored in an i8, so they cannot exceed i8::MAX.
    let unsigned_levels = full_levels.min(f32::from(i8::MAX));
    let signed_levels = (half_span - 1) as f32; // 2^(n-1) - 1
    let range = max_val - min_val;
    let abs_max = max_val.abs().max(min_val.abs());

    let (scale, zero_point) = match method {
        QuantizationMethod::Uniform => (range / unsigned_levels, 0),
        QuantizationMethod::PowerOfTwo => {
            let ideal_scale = range / unsigned_levels;
            (2.0_f32.powf(ideal_scale.log2().ceil()), 0)
        }
        // Int4 is the 4-bit case of Symmetric: abs_max / 7.
        QuantizationMethod::Symmetric | QuantizationMethod::Int4 => {
            (abs_max / signed_levels, 0)
        }
        QuantizationMethod::Affine => {
            let scale = range / full_levels;
            // Shift the zero point down by 2^(n-1) so the 2^n codes occupy the
            // signed range [-2^(n-1), 2^(n-1) - 1] and fit in an i8; the
            // dequantization formula scale * (q - zero_point) is unaffected.
            (
                scale,
                ((-min_val / scale).round() as i32).saturating_sub(half_span),
            )
        }
        // UInt4 is the 4-bit unsigned case: range / 15.
        QuantizationMethod::UInt4 => {
            let scale = range / full_levels;
            (scale, (-min_val / scale).round() as i32)
        }
        _ => unreachable!("method was validated above"),
    };

    let params = QuantizationParams {
        bits: effective_bits,
        scale,
        zero_point,
        min_val,
        max_val,
        method,
        data_type,
        channel_scales: None,
        channel_zero_points: None,
    };

    match method {
        QuantizationMethod::Int4 | QuantizationMethod::UInt4 => {
            let (origin, lo, hi, packed_type) = if method == QuantizationMethod::Int4 {
                (0.0, -8.0, 7.0, QuantizedDataType::Int4)
            } else {
                (min_val, 0.0, 15.0, QuantizedDataType::UInt4)
            };
            // `iter()` walks the view in logical row-major order, so this works for
            // non-contiguous views too (unlike `as_slice`).
            let codes = matrix.iter().map(|&v| {
                let x: f32 = v.as_();
                ((x - origin) / scale).round().clamp(lo, hi) as i8
            });
            // Codes are packed over the flat element index, so rows are not
            // byte-aligned when ncols is odd.
            // NOTE(review): this layout must match QuantizedMatrix::get_i8 — confirm.
            let mut packed = Array2::<i8>::zeros((shape.0, shape.1.div_ceil(2)));
            pack_nibbles(
                codes,
                packed
                    .as_slice_mut()
                    .expect("freshly allocated array is contiguous"),
            );
            (QuantizedMatrix::new_i8(packed, shape, packed_type), params)
        }
        _ => {
            // code = clamp(round((x - origin) / scale + offset), lo, hi)
            let (origin, offset, lo, hi) = match method {
                QuantizationMethod::Uniform | QuantizationMethod::PowerOfTwo => {
                    (min_val, 0.0, 0.0, unsigned_levels)
                }
                QuantizationMethod::Symmetric => (0.0, 0.0, -signed_levels, signed_levels),
                QuantizationMethod::Affine => (
                    0.0,
                    zero_point as f32,
                    (-half_span) as f32,
                    (half_span - 1) as f32,
                ),
                _ => unreachable!("method was validated above"),
            };
            let quantized = Array2::from_shape_fn(shape, |(r, c)| {
                let x: f32 = matrix[[r, c]].as_();
                ((x - origin) / scale + offset).round().clamp(lo, hi) as i8
            });
            (
                QuantizedMatrix::new_i8(quantized, shape, QuantizedDataType::Int8),
                params,
            )
        }
    }
}
/// Quantizes each column of `matrix` independently (one channel per column),
/// storing signed codes as `i8`.
///
/// Every column gets its own finite value range: NaN/±inf are ignored, constant
/// columns are widened to `[min, min + 1]`, and columns with no finite values
/// use `[0, 1]`. From that range each column gets its own scale and zero point:
///
/// * `PerChannelSymmetric` — `scale = max(|min|, |max|) / (2^(bits-1) - 1)`,
///   zero point `0`, codes `round(x / scale)`.
/// * `PerChannelAffine` — `scale = (max - min) / (2^bits - 1)`, with the zero
///   point chosen so the column range maps onto `[-2^(bits-1), 2^(bits-1) - 1]`.
///   Codes are `round(x / scale) + zero_point`, which uses all `2^bits` codes.
///
/// Codes are clamped to `[-2^(bits-1), 2^(bits-1) - 1]`. The returned params hold
/// the per-column values in `channel_scales` / `channel_zero_points`.
/// `scale` / `zero_point` are their column means, or `1` / `0` for a matrix with
/// no columns. `min_val` / `max_val` are the extremes over all column ranges.
///
/// # Panics
///
/// * If `method` is not `PerChannelSymmetric` or `PerChannelAffine`.
/// * If `bits` is outside `1..=8`, because wider codes cannot be stored in an `i8`.
#[allow(dead_code)]
pub fn quantize_matrix_per_channel<F>(
    matrix: &ArrayView2<F>,
    bits: u8,
    method: QuantizationMethod,
) -> (QuantizedMatrix, QuantizationParams)
where
    F: Float + Debug + AsPrimitive<f32> + FromPrimitive,
    f32: AsPrimitive<F>,
{
    assert!(
        method == QuantizationMethod::PerChannelSymmetric
            || method == QuantizationMethod::PerChannelAffine,
        "quantize_matrix_per_channel requires PerChannelSymmetric or PerChannelAffine method, got {method:?}"
    );
    assert!(
        (1..=8).contains(&bits),
        "quantize_matrix_per_channel stores codes as i8, so bits must be in 1..=8, got {bits}"
    );
    let affine = method == QuantizationMethod::PerChannelAffine;
    let shape = (matrix.nrows(), matrix.ncols());
    let num_channels = shape.1;

    // Signed code range shared by both methods: [-2^(bits-1), 2^(bits-1) - 1].
    let half_span = 1i32 << (bits - 1);
    let (code_min, code_max) = ((-half_span) as f32, (half_span - 1) as f32);

    // Finite per-column value ranges.
    let mut channel_min_vals = Vec::with_capacity(num_channels);
    let mut channel_max_vals = Vec::with_capacity(num_channels);
    for col in 0..num_channels {
        let (mut lo, mut hi) = (f32::INFINITY, f32::NEG_INFINITY);
        for &val in matrix.column(col).iter() {
            let v: f32 = val.as_();
            if v.is_finite() {
                lo = lo.min(v);
                hi = hi.max(v);
            }
        }
        if lo > hi {
            // Empty or entirely non-finite column: avoid infinite scales.
            lo = 0.0;
            hi = 1.0;
        } else if hi - lo < f32::EPSILON {
            // Constant column: widen so the scale is non-zero.
            hi = lo + 1.0;
        }
        channel_min_vals.push(lo);
        channel_max_vals.push(hi);
    }

    let (channel_scales, channel_zero_points): (Vec<f32>, Vec<i32>) = channel_min_vals
        .iter()
        .zip(&channel_max_vals)
        .map(|(&lo, &hi)| {
            if affine {
                let scale = (hi - lo) / ((1i32 << bits) - 1) as f32;
                // Shift by 2^(bits-1) so the column range lands in the signed code
                // range; dequantization via scale * (q - zero_point) is unaffected.
                (
                    scale,
                    ((-lo / scale).round() as i32).saturating_sub(half_span),
                )
            } else {
                (hi.abs().max(lo.abs()) / (half_span - 1) as f32, 0)
            }
        })
        .unzip();

    let (min_val, max_val) = if num_channels == 0 {
        (0.0, 1.0)
    } else {
        (
            channel_min_vals.iter().copied().fold(f32::INFINITY, f32::min),
            channel_max_vals.iter().copied().fold(f32::NEG_INFINITY, f32::max),
        )
    };

    // Summary (mean) scale/zero point; guarded so a column-less matrix cannot
    // produce NaN, and summed in i64 so many channels cannot overflow.
    let (scale, zero_point) = if num_channels == 0 {
        (1.0, 0)
    } else {
        let n = num_channels as f32;
        let mean_scale = channel_scales.iter().sum::<f32>() / n;
        let mean_zero_point = if affine {
            (channel_zero_points.iter().map(|&z| i64::from(z)).sum::<i64>() as f32 / n).round()
                as i32
        } else {
            0
        };
        (mean_scale, mean_zero_point)
    };

    let quantized_data = Array2::from_shape_fn(shape, |(row, col)| {
        let v: f32 = matrix[[row, col]].as_();
        let offset = if affine {
            channel_zero_points[col] as f32
        } else {
            0.0
        };
        (v / channel_scales[col] + offset)
            .round()
            .clamp(code_min, code_max) as i8
    });

    let params = QuantizationParams {
        bits,
        scale,
        zero_point,
        min_val,
        max_val,
        method,
        data_type: QuantizedDataType::Int8,
        channel_scales: Some(channel_scales),
        channel_zero_points: Some(channel_zero_points),
    };
    (
        QuantizedMatrix::new_i8(quantized_data, shape, QuantizedDataType::Int8),
        params,
    )
}
/// Reconstructs approximate `f32` values from a [`QuantizedMatrix`].
///
/// `params` is expected to be the parameter set produced together with
/// `quantized`. The storage variant and `quantized.data_type` select the decoding:
///
/// * `Float16` / `BFloat16` storage — converted element-wise.
/// * `Int4` / `UInt4` codes (read through `get_i8`) — `q * scale` for `Int4`,
///   `min_val + q * scale` for `UInt4`.
/// * `Int8` with a per-channel method — per column `c`, `q * channel_scales[c]`
///   (symmetric) or `channel_scales[c] * (q - channel_zero_points[c])` (affine).
/// * `Int8` otherwise — `min_val + q * scale` (`Uniform`, `PowerOfTwo`),
///   `q * scale` (`Symmetric`) or `scale * (q - zero_point)` (`Affine`).
///
/// # Panics
///
/// Panics when `params.method` does not fit the storage type (for 4-bit data
/// only once an element is decoded). It also panics when a per-channel method
/// lacks `channel_scales` / `channel_zero_points`.
#[allow(dead_code)]
pub fn dequantize_matrix(quantized: &QuantizedMatrix, params: &QuantizationParams) -> Array2<f32> {
    // Writes `values` into `dst` by flat (row-major slice) position.
    fn write_flat(dst: &mut Array2<f32>, values: impl Iterator<Item = f32>) {
        for (pos, value) in values.enumerate() {
            dst.as_slice_mut().expect("Operation failed")[pos] = value;
        }
    }

    let shape = quantized.shape();
    let mut out = Array2::<f32>::zeros(shape);

    match &quantized.data {
        QuantizedData2D::Float16(halves) => {
            write_flat(&mut out, halves.iter().map(|&h| h.to_f32()));
        }
        QuantizedData2D::BFloat16(halves) => {
            write_flat(&mut out, halves.iter().map(|&h| h.to_f32()));
        }
        QuantizedData2D::Int8(codes) => match quantized.data_type {
            QuantizedDataType::Int4 | QuantizedDataType::UInt4 => {
                for row in 0..shape.0 {
                    for col in 0..shape.1 {
                        let code = quantized.get_i8(row, col) as f32;
                        out[[row, col]] = match params.method {
                            QuantizationMethod::Int4 => code * params.scale,
                            QuantizationMethod::UInt4 => params.min_val + (code * params.scale),
                            _ => unreachable!(),
                        };
                    }
                }
            }
            QuantizedDataType::Int8
                if matches!(
                    params.method,
                    QuantizationMethod::PerChannelSymmetric | QuantizationMethod::PerChannelAffine
                ) =>
            {
                let scales = params
                    .channel_scales
                    .as_ref()
                    .expect("Per-channel quantization requires channel_scales");
                let zero_points = params
                    .channel_zero_points
                    .as_ref()
                    .expect("Per-channel quantization requires channel_zero_points");
                for row in 0..shape.0 {
                    for col in 0..shape.1 {
                        let code = codes[[row, col]];
                        let (s, z) = (scales[col], zero_points[col]);
                        out[[row, col]] =
                            if params.method == QuantizationMethod::PerChannelSymmetric {
                                code as f32 * s
                            } else {
                                s * (code as f32 - z as f32)
                            };
                    }
                }
            }
            QuantizedDataType::Int8 => match params.method {
                QuantizationMethod::Uniform | QuantizationMethod::PowerOfTwo => write_flat(
                    &mut out,
                    codes
                        .iter()
                        .map(|&q| params.min_val + (q as f32 * params.scale)),
                ),
                QuantizationMethod::Symmetric => {
                    write_flat(&mut out, codes.iter().map(|&q| q as f32 * params.scale))
                }
                QuantizationMethod::Affine => write_flat(
                    &mut out,
                    codes
                        .iter()
                        .map(|&q| params.scale * (q as f32 - params.zero_point as f32)),
                ),
                _ => unreachable!(),
            },
            _ => unreachable!(),
        },
    }
    out
}
pub fn quantize_vector<F>(
vector: &ArrayView1<F>,
bits: u8,
method: QuantizationMethod,
) -> (QuantizedVector, QuantizationParams)
where
F: Float + Debug + AsPrimitive<f32> + FromPrimitive,
f32: AsPrimitive<F>,
{
let length = vector.len();
let mut min_val = F::infinity().as_();
let mut max_val = F::neg_infinity().as_();
for &val in vector.iter() {
let val_f32: f32 = val.as_();
if val_f32.is_finite() {
min_val = min_val.min(val_f32);
max_val = max_val.max(val_f32);
}
}
if (max_val - min_val).abs() < f32::EPSILON {
max_val = min_val + 1.0;
}
if method == QuantizationMethod::Float16 {
let mut f16_data = Array1::zeros(length);
for (i, &val) in vector.iter().enumerate() {
let val_f32: f32 = val.as_();
f16_data[i] = f16::from_f32(val_f32);
}
let params = QuantizationParams {
bits: 16,
scale: 1.0, zero_point: 0,
min_val,
max_val,
method,
data_type: QuantizedDataType::Float16,
channel_scales: None,
channel_zero_points: None,
};
return (QuantizedVector::new_f16(f16_data, length), params);
}
if method == QuantizationMethod::BFloat16 {
let mut bf16_data = Array1::zeros(length);
for (i, &val) in vector.iter().enumerate() {
let val_f32: f32 = val.as_();
bf16_data[i] = bf16::from_f32(val_f32);
}
let params = QuantizationParams {
bits: 16,
scale: 1.0, zero_point: 0,
min_val,
max_val,
method,
data_type: QuantizedDataType::BFloat16,
channel_scales: None,
channel_zero_points: None,
};
return (QuantizedVector::new_bf16(bf16_data, length), params);
}
let data_type = match method {
QuantizationMethod::Int4 => QuantizedDataType::Int4,
QuantizationMethod::UInt4 => QuantizedDataType::UInt4,
_ => QuantizedDataType::Int8,
};
let effective_bits = match method {
QuantizationMethod::Int4 | QuantizationMethod::UInt4 => 4,
_ => bits,
};
let (scale, zero_point) = match method {
QuantizationMethod::Uniform => {
let scale = (max_val - min_val) / ((1 << effective_bits) - 1) as f32;
let zero_point = 0;
(scale, zero_point)
}
QuantizationMethod::Symmetric => {
let abs_max = max_val.abs().max(min_val.abs());
let scale = abs_max / ((1 << (effective_bits - 1)) - 1) as f32;
let zero_point = 0;
(scale, zero_point)
}
QuantizationMethod::Affine => {
let scale = (max_val - min_val) / ((1 << effective_bits) - 1) as f32;
let zero_point = (-min_val / scale).round() as i32;
(scale, zero_point)
}
QuantizationMethod::PowerOfTwo => {
let range = max_val - min_val;
let ideal_scale = range / ((1 << effective_bits) - 1) as f32;
let exponent = ideal_scale.log2().ceil();
let scale = 2.0_f32.powf(exponent);
let zero_point = 0;
(scale, zero_point)
}
QuantizationMethod::Int4 => {
let abs_max = max_val.abs().max(min_val.abs());
let scale = abs_max / 7.0; let zero_point = 0;
(scale, zero_point)
}
QuantizationMethod::UInt4 => {
let scale = (max_val - min_val) / 15.0; let zero_point = (-min_val / scale).round() as i32;
(scale, zero_point)
}
_ => unreachable!(), };
let params = QuantizationParams {
bits: effective_bits,
scale,
zero_point,
min_val,
max_val,
method,
data_type,
channel_scales: None,
channel_zero_points: None,
};
match method {
QuantizationMethod::Int4 => {
let packedsize = length.div_ceil(2); let mut packed_data = Array1::zeros(packedsize);
for i in 0..length {
let val_f32: f32 = vector[i].as_();
let q_val = ((val_f32 / scale).round() as i8).clamp(-8, 7);
let byte_idx = i / 2;
if i % 2 == 0 {
packed_data[byte_idx] = q_val << 4;
} else {
packed_data[byte_idx] |= q_val & 0x0F;
}
}
(
QuantizedVector::new_i8(packed_data, length, QuantizedDataType::Int4),
params,
)
}
QuantizationMethod::UInt4 => {
let packedsize = length.div_ceil(2); let mut packed_data = Array1::zeros(packedsize);
for i in 0..length {
let val_f32: f32 = vector[i].as_();
let ival = ((val_f32 - min_val) / scale).round() as i32;
let q_val = (ival.clamp(0, 15) & 0x0F) as i8;
let byte_idx = i / 2;
if i % 2 == 0 {
packed_data[byte_idx] = q_val << 4;
} else {
packed_data[byte_idx] |= q_val & 0x0F;
}
}
(
QuantizedVector::new_i8(packed_data, length, QuantizedDataType::UInt4),
params,
)
}
_ => {
let quantized_data = match method {
QuantizationMethod::Uniform => {
let mut quantized = Array1::zeros(length);
for (i, &val) in vector.iter().enumerate() {
let val_f32: f32 = val.as_();
let q_val = ((val_f32 - min_val) / scale).round() as i8;
quantized[i] = q_val;
}
quantized
}
QuantizationMethod::Symmetric => {
let mut quantized = Array1::zeros(length);
for (i, &val) in vector.iter().enumerate() {
let val_f32: f32 = val.as_();
let q_val = (val_f32 / scale).round() as i8;
quantized[i] = q_val;
}
quantized
}
QuantizationMethod::Affine => {
let mut quantized = Array1::zeros(length);
for (i, &val) in vector.iter().enumerate() {
let val_f32: f32 = val.as_();
let q_val = ((val_f32 / scale) + zero_point as f32).round() as i8;
quantized[i] = q_val;
}
quantized
}
QuantizationMethod::PowerOfTwo => {
let mut quantized = Array1::zeros(length);
for (i, &val) in vector.iter().enumerate() {
let val_f32: f32 = val.as_();
let q_val = ((val_f32 - min_val) / scale).round() as i8;
quantized[i] = q_val;
}
quantized
}
_ => unreachable!(), };
(
QuantizedVector::new_i8(quantized_data, length, QuantizedDataType::Int8),
params,
)
}
}
}
/// Reconstructs approximate `f32` values from a [`QuantizedVector`].
///
/// `params` is expected to be the parameter set produced together with
/// `quantized`. The storage variant and `quantized.data_type` select the decoding:
///
/// * `Float16` / `BFloat16` storage — converted element-wise.
/// * `Int4` / `UInt4` codes (read through `get_i8`) — `q * scale` for `Int4`,
///   `min_val + q * scale` for `UInt4`.
/// * `Int8` codes — `min_val + q * scale` (`Uniform`, `PowerOfTwo`),
///   `q * scale` (`Symmetric`) or `scale * (q - zero_point)` (`Affine`).
///
/// # Panics
///
/// Panics when `params.method` does not fit the storage type. For 4-bit data
/// this happens only once an element is decoded.
pub fn dequantize_vector_public(
    quantized: &QuantizedVector,
    params: &QuantizationParams,
) -> Array1<f32> {
    // Stores `values` into `dst` at positions 0, 1, 2, ...
    fn write_indexed(dst: &mut Array1<f32>, values: impl Iterator<Item = f32>) {
        for (pos, value) in values.enumerate() {
            dst[pos] = value;
        }
    }

    let length = quantized.len();
    let mut out = Array1::<f32>::zeros(length);

    match &quantized.data {
        QuantizedData1D::Float16(halves) => {
            write_indexed(&mut out, halves.iter().map(|&h| h.to_f32()));
        }
        QuantizedData1D::BFloat16(halves) => {
            write_indexed(&mut out, halves.iter().map(|&h| h.to_f32()));
        }
        QuantizedData1D::Int8(codes) => match quantized.data_type {
            QuantizedDataType::Int4 | QuantizedDataType::UInt4 => {
                let decoded = (0..length).map(|idx| {
                    let code = quantized.get_i8(idx) as f32;
                    match params.method {
                        QuantizationMethod::Int4 => code * params.scale,
                        QuantizationMethod::UInt4 => params.min_val + (code * params.scale),
                        _ => unreachable!(),
                    }
                });
                write_indexed(&mut out, decoded);
            }
            QuantizedDataType::Int8 => match params.method {
                QuantizationMethod::Uniform | QuantizationMethod::PowerOfTwo => write_indexed(
                    &mut out,
                    codes
                        .iter()
                        .map(|&q| params.min_val + (q as f32 * params.scale)),
                ),
                QuantizationMethod::Symmetric => {
                    write_indexed(&mut out, codes.iter().map(|&q| q as f32 * params.scale))
                }
                QuantizationMethod::Affine => write_indexed(
                    &mut out,
                    codes
                        .iter()
                        .map(|&q| params.scale * (q as f32 - params.zero_point as f32)),
                ),
                _ => unreachable!(),
            },
            _ => unreachable!(),
        },
    }
    out
}
/// Simulates quantization error on a matrix. It quantizes `matrix` with
/// [`quantize_matrix`], immediately dequantizes the result with
/// [`dequantize_matrix`], and converts back to the input element type `F`.
///
/// The result has the same dimensions as `matrix`; values are copied in
/// row-major order.
///
/// # Panics
///
/// Panics under the same conditions as [`quantize_matrix`] and
/// [`dequantize_matrix`], or if a dequantized `f32` cannot be converted to `F`.
pub fn fake_quantize<F>(matrix: &ArrayView2<F>, bits: u8, method: QuantizationMethod) -> Array2<F>
where
    F: Float + Debug + AsPrimitive<f32> + FromPrimitive,
    f32: AsPrimitive<F>,
{
    let (quantized, params) = quantize_matrix(matrix, bits, method);
    let dequantized = dequantize_matrix(&quantized, &params);
    let mut result = Array2::<F>::zeros(matrix.dim());
    // A freshly allocated array is contiguous; borrow its slice once.
    let out = result.as_slice_mut().expect("Operation failed");
    for (i, &val) in dequantized.iter().enumerate() {
        out[i] = F::from_f32(val).expect("Operation failed");
    }
    result
}
/// Simulates quantization error on a vector. It quantizes `vector` with
/// [`quantize_vector`], immediately dequantizes the result with
/// [`dequantize_vector_public`], and converts back to the input element type `F`.
///
/// The result has the same length as `vector`.
///
/// # Panics
///
/// Panics under the same conditions as [`quantize_vector`] and
/// [`dequantize_vector_public`], or if a dequantized `f32` cannot be converted
/// to `F`.
pub fn fake_quantize_vector<F>(
    vector: &ArrayView1<F>,
    bits: u8,
    method: QuantizationMethod,
) -> Array1<F>
where
    F: Float + Debug + AsPrimitive<f32> + FromPrimitive,
    f32: AsPrimitive<F>,
{
    let (quantized, params) = quantize_vector(vector, bits, method);
    let dequantized = dequantize_vector_public(&quantized, &params);
    let mut result = Array1::<F>::zeros(vector.dim());
    for (i, &val) in dequantized.iter().enumerate() {
        result[i] = F::from_f32(val).expect("Operation failed");
    }
    result
}