use half::{bf16, f16};
use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use scirs2_core::numeric::{AsPrimitive, Float, FromPrimitive};
use std::fmt::Debug;
use crate::error::{LinalgError, LinalgResult};
pub mod calibration;
pub mod calibration_ema;
pub mod fusion;
pub mod out_of_core;
pub mod quantized_matrixfree;
pub mod simd;
pub mod solvers;
pub mod stability;
/// Strategy used to map floating-point values to a quantized representation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QuantizationMethod {
    /// Codes are offsets from `min_val` in steps of `scale` (zero point 0).
    Uniform,
    /// Symmetric around zero; scale derived from the largest magnitude.
    Symmetric,
    /// Scale plus a non-zero zero point chosen so `min_val` maps to code 0.
    Affine,
    /// Like `Uniform`, but with the scale rounded up to a power of two.
    PowerOfTwo,
    /// Signed 4-bit symmetric codes, packed two per `i8` byte.
    Int4,
    /// Unsigned 4-bit affine codes, packed two per `i8` byte.
    UInt4,
    /// Store values directly as IEEE 754 half precision (`f16`).
    Float16,
    /// Store values directly as bfloat16 (`bf16`).
    BFloat16,
    /// Symmetric quantization with an independent scale per column.
    PerChannelSymmetric,
    /// Affine quantization with an independent scale/zero point per column.
    PerChannelAffine,
}
/// Parameters produced by quantization and required to dequantize the result.
#[derive(Debug, Clone)]
pub struct QuantizationParams {
    /// Effective bit width of the codes (4 for Int4/UInt4, 16 for f16/bf16).
    pub bits: u8,
    /// Global scale factor mapping integer codes back to real values.
    pub scale: f32,
    /// Global zero point (non-zero only for affine-style methods).
    pub zero_point: i32,
    /// Smallest finite value observed during calibration.
    pub min_val: f32,
    /// Largest finite value observed during calibration.
    pub max_val: f32,
    /// Method that produced these parameters.
    pub method: QuantizationMethod,
    /// Physical storage type of the quantized data.
    pub data_type: QuantizedDataType,
    /// Per-column scales for per-channel methods; `None` otherwise.
    pub channel_scales: Option<Vec<f32>>,
    /// Per-column zero points for per-channel methods; `None` otherwise.
    pub channel_zero_points: Option<Vec<i32>>,
}
/// Physical storage type of quantized values.
///
/// `Int4`/`UInt4` codes are packed two per byte inside `i8` storage; all
/// other variants store one element per slot.
///
/// Derives `Copy` (it is a fieldless enum that is cloned throughout this
/// module), consistent with `QuantizationMethod`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QuantizedDataType {
    /// One signed 8-bit code per byte.
    Int8,
    /// Signed 4-bit codes, two packed per `i8` byte.
    Int4,
    /// Unsigned 4-bit codes, two packed per `i8` byte.
    UInt4,
    /// IEEE 754 half precision (`f16`).
    Float16,
    /// bfloat16 (`bf16`).
    BFloat16,
}
/// A quantized 2-D matrix: packed storage plus its logical shape.
#[derive(Debug, Clone)]
pub struct QuantizedMatrix {
    /// Underlying storage. For Int4/UInt4 two codes are packed per `i8`
    /// byte, so the storage shape differs from the logical `shape`.
    pub data: QuantizedData2D,
    /// Logical (rows, cols) of the original matrix.
    pub shape: (usize, usize),
    /// Interpretation of `data` (Int8, packed Int4/UInt4, Float16, BFloat16).
    pub data_type: QuantizedDataType,
}
/// A quantized 1-D vector: packed storage plus its logical length.
#[derive(Debug, Clone)]
pub struct QuantizedVector {
    /// Underlying storage. For Int4/UInt4 two codes are packed per `i8`
    /// byte, so the storage length differs from the logical `length`.
    pub data: QuantizedData1D,
    /// Logical number of elements in the original vector.
    pub length: usize,
    /// Interpretation of `data` (Int8, packed Int4/UInt4, Float16, BFloat16).
    pub data_type: QuantizedDataType,
}
/// Storage backing a quantized matrix.
#[derive(Debug, Clone)]
pub enum QuantizedData2D {
    /// Signed 8-bit storage; also carries packed Int4/UInt4 nibbles.
    Int8(Array2<i8>),
    /// IEEE half-precision storage.
    Float16(Array2<f16>),
    /// bfloat16 storage.
    BFloat16(Array2<bf16>),
}
impl QuantizedData2D {
pub fn len(&self) -> usize {
match self {
QuantizedData2D::Int8(arr) => arr.len(),
QuantizedData2D::Float16(arr) => arr.len(),
QuantizedData2D::BFloat16(arr) => arr.len(),
}
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
}
/// Storage backing a quantized vector.
#[derive(Debug, Clone)]
pub enum QuantizedData1D {
    /// Signed 8-bit storage; also carries packed Int4/UInt4 nibbles.
    Int8(Array1<i8>),
    /// IEEE half-precision storage.
    Float16(Array1<f16>),
    /// bfloat16 storage.
    BFloat16(Array1<bf16>),
}
/// Borrow the raw `i8` storage of `matrix`, or `None` for float-backed data.
#[allow(dead_code)]
pub fn get_quantizedmatrix_2d_i8(matrix: &QuantizedMatrix) -> Option<&Array2<i8>> {
    if let QuantizedData2D::Int8(data) = &matrix.data {
        Some(data)
    } else {
        None
    }
}
/// Borrow the raw `i8` storage of `vector`, or `None` for float-backed data.
#[allow(dead_code)]
pub fn get_quantized_vector_1d_i8(vector: &QuantizedVector) -> Option<&Array1<i8>> {
    if let QuantizedData1D::Int8(data) = &vector.data {
        Some(data)
    } else {
        None
    }
}
impl QuantizedData1D {
pub fn len(&self) -> usize {
match self {
QuantizedData1D::Int8(arr) => arr.len(),
QuantizedData1D::Float16(arr) => arr.len(),
QuantizedData1D::BFloat16(arr) => arr.len(),
}
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
}
impl QuantizedMatrix {
    /// Build a matrix over `i8` storage with an explicit `data_type`
    /// (`Int8`, or packed `Int4`/`UInt4` with two codes per byte).
    pub fn new_i8(data: Array2<i8>, shape: (usize, usize), data_type: QuantizedDataType) -> Self {
        Self {
            data: QuantizedData2D::Int8(data),
            shape,
            data_type,
        }
    }

    /// Build a matrix over IEEE half-precision storage.
    pub fn new_f16(data: Array2<f16>, shape: (usize, usize)) -> Self {
        Self {
            data: QuantizedData2D::Float16(data),
            shape,
            data_type: QuantizedDataType::Float16,
        }
    }

    /// Build a matrix over bfloat16 storage.
    pub fn new_bf16(data: Array2<bf16>, shape: (usize, usize)) -> Self {
        Self {
            data: QuantizedData2D::BFloat16(data),
            shape,
            data_type: QuantizedDataType::BFloat16,
        }
    }

    /// Convenience constructor for plain `Int8` storage.
    pub fn from_i8(data: Array2<i8>, shape: (usize, usize)) -> Self {
        Self {
            data: QuantizedData2D::Int8(data),
            shape,
            data_type: QuantizedDataType::Int8,
        }
    }

    #[deprecated(since = "0.1.0", note = "Use get_i8 or get_f32 instead")]
    pub fn get(&self, row: usize, col: usize) -> i8 {
        self.get_i8(row, col)
    }

    /// Logical (rows, cols) of the matrix.
    pub fn shape(&self) -> (usize, usize) {
        self.shape
    }

    /// Number of logical rows.
    pub fn nrows(&self) -> usize {
        self.shape.0
    }

    /// Number of logical columns.
    pub fn ncols(&self) -> usize {
        self.shape.1
    }

    /// Read the integer code at (`row`, `col`) from integer-backed storage.
    ///
    /// For `Int4`/`UInt4` the codes are packed two per byte in flat
    /// row-major order over the logical matrix, high nibble first.
    ///
    /// # Panics
    /// Panics if the storage is Float16/BFloat16, or (for packed types) if
    /// the backing array is not contiguous.
    pub fn get_i8(&self, row: usize, col: usize) -> i8 {
        match &self.data {
            QuantizedData2D::Int8(arr) => {
                match self.data_type {
                    QuantizedDataType::Int8 => arr[[row, col]],
                    QuantizedDataType::Int4 => {
                        let idx = row * self.shape.1 + col;
                        let byte_idx = idx / 2;
                        let nibble_idx = idx % 2;
                        let byte = arr.as_slice().expect("Operation failed")[byte_idx];
                        if nibble_idx == 0 {
                            // High nibble: arithmetic shift right sign-extends.
                            byte >> 4
                        } else {
                            // Low nibble: shift up, then arithmetically back
                            // down, to sign-extend the 4-bit value. A plain
                            // `byte & 0x0F` decoded negatives incorrectly
                            // (e.g. -1 stored as 0x0F read back as 15).
                            (byte << 4) >> 4
                        }
                    }
                    QuantizedDataType::UInt4 => {
                        let idx = row * self.shape.1 + col;
                        let byte_idx = idx / 2;
                        let nibble_idx = idx % 2;
                        let byte = arr.as_slice().expect("Operation failed")[byte_idx];
                        if nibble_idx == 0 {
                            // Mask after the shift: unsigned 4-bit code.
                            (byte >> 4) & 0x0F
                        } else {
                            byte & 0x0F
                        }
                    }
                    _ => unreachable!(
                        "Invalid quantization type for Int8 storage: expected Int8, Int4, or UInt4"
                    ),
                }
            }
            _ => unreachable!("Cannot get i8 value from floating-point quantized matrix"),
        }
    }

    /// Read the element at (`row`, `col`) as `f32`: the raw integer code for
    /// integer-backed types, or the converted value for float-backed types.
    /// No dequantization (scale/zero point) is applied.
    pub fn get_f32(&self, row: usize, col: usize) -> f32 {
        match &self.data {
            QuantizedData2D::Int8(arr) => match self.data_type {
                QuantizedDataType::Int8 => arr[[row, col]] as f32,
                QuantizedDataType::Int4 => self.get_i8(row, col) as f32,
                QuantizedDataType::UInt4 => self.get_i8(row, col) as f32,
                _ => unreachable!(
                    "Invalid data type for Int8 storage: expected Int8, Int4, or UInt4"
                ),
            },
            QuantizedData2D::Float16(arr) => arr[[row, col]].to_f32(),
            QuantizedData2D::BFloat16(arr) => arr[[row, col]].to_f32(),
        }
    }
}
impl QuantizedVector {
    /// Build a vector over `i8` storage with an explicit data type
    /// (`Int8`, or packed `Int4`/`UInt4` with two codes per byte).
    pub fn new_i8(data: Array1<i8>, length: usize, datatype: QuantizedDataType) -> Self {
        Self {
            data: QuantizedData1D::Int8(data),
            length,
            data_type: datatype,
        }
    }

    /// Build a vector over IEEE half-precision storage.
    pub fn new_f16(data: Array1<f16>, length: usize) -> Self {
        Self {
            data: QuantizedData1D::Float16(data),
            length,
            data_type: QuantizedDataType::Float16,
        }
    }

    /// Build a vector over bfloat16 storage.
    pub fn new_bf16(data: Array1<bf16>, length: usize) -> Self {
        Self {
            data: QuantizedData1D::BFloat16(data),
            length,
            data_type: QuantizedDataType::BFloat16,
        }
    }

    /// Convenience constructor for plain `Int8` storage.
    pub fn from_i8(data: Array1<i8>, length: usize) -> Self {
        Self {
            data: QuantizedData1D::Int8(data),
            length,
            data_type: QuantizedDataType::Int8,
        }
    }

    #[deprecated(since = "0.1.0", note = "Use get_i8 or get_f32 instead")]
    pub fn get(&self, idx: usize) -> i8 {
        self.get_i8(idx)
    }

    /// Logical number of elements (not the packed storage length).
    pub fn len(&self) -> usize {
        self.length
    }

    /// True when the vector has no elements.
    pub fn is_empty(&self) -> bool {
        self.length == 0
    }

    /// Read the integer code at `idx` from integer-backed storage.
    ///
    /// For `Int4`/`UInt4` two codes are packed per byte, high nibble first.
    ///
    /// # Panics
    /// Panics if the storage is Float16/BFloat16.
    pub fn get_i8(&self, idx: usize) -> i8 {
        match &self.data {
            QuantizedData1D::Int8(arr) => {
                match self.data_type {
                    QuantizedDataType::Int8 => arr[idx],
                    QuantizedDataType::Int4 => {
                        let byte_idx = idx / 2;
                        let nibble_idx = idx % 2;
                        let byte = arr[byte_idx];
                        if nibble_idx == 0 {
                            // High nibble: arithmetic shift right sign-extends.
                            byte >> 4
                        } else {
                            // Low nibble: shift up, then arithmetically back
                            // down, to sign-extend the 4-bit value. A plain
                            // `byte & 0x0F` decoded negatives incorrectly
                            // (e.g. -1 stored as 0x0F read back as 15).
                            (byte << 4) >> 4
                        }
                    }
                    QuantizedDataType::UInt4 => {
                        let byte_idx = idx / 2;
                        let nibble_idx = idx % 2;
                        let byte = arr[byte_idx];
                        if nibble_idx == 0 {
                            // Mask after the shift: unsigned 4-bit code.
                            (byte >> 4) & 0x0F
                        } else {
                            byte & 0x0F
                        }
                    }
                    _ => unreachable!(
                        "Invalid quantization type for Int8 storage: expected Int8, Int4, or UInt4"
                    ),
                }
            }
            _ => unreachable!("Cannot get i8 value from floating-point quantized vector"),
        }
    }

    /// Read the element at `idx` as `f32`: the raw integer code for
    /// integer-backed types, or the converted value for float-backed types.
    /// No dequantization (scale/zero point) is applied.
    pub fn get_f32(&self, idx: usize) -> f32 {
        match &self.data {
            QuantizedData1D::Int8(arr) => match self.data_type {
                QuantizedDataType::Int8 => arr[idx] as f32,
                QuantizedDataType::Int4 => self.get_i8(idx) as f32,
                QuantizedDataType::UInt4 => self.get_i8(idx) as f32,
                _ => unreachable!(
                    "Invalid data type for Int8 storage: expected Int8, Int4, or UInt4"
                ),
            },
            QuantizedData1D::Float16(arr) => arr[idx].to_f32(),
            QuantizedData1D::BFloat16(arr) => arr[idx].to_f32(),
        }
    }
}
/// Quantize a floating-point matrix with the given bit width and `method`,
/// returning the quantized matrix together with the parameters needed to
/// dequantize it.
///
/// * `Float16`/`BFloat16` convert element-wise (no integer codes).
/// * `Int4`/`UInt4` force a 4-bit width and pack two codes per `i8` byte.
/// * All other methods produce one `i8` code per element using `bits`.
///
/// Per-channel methods are handled by `quantize_matrix_per_channel`, not here.
///
/// # Panics
/// Panics (`expect("Operation failed")`) if `matrix` cannot be viewed as a
/// contiguous slice in the Int4/UInt4 packing paths.
#[allow(dead_code)]
pub fn quantize_matrix<F>(
    matrix: &ArrayView2<F>,
    bits: u8,
    method: QuantizationMethod,
) -> (QuantizedMatrix, QuantizationParams)
where
    F: Float + Debug + AsPrimitive<f32> + FromPrimitive,
    f32: AsPrimitive<F>,
{
    let shape = (matrix.nrows(), matrix.ncols());
    // Calibration pass: finite min/max of the input.
    let mut min_val = F::infinity().as_();
    let mut max_val = F::neg_infinity().as_();
    for &val in matrix.iter() {
        let val_f32: f32 = val.as_();
        if val_f32.is_finite() {
            min_val = min_val.min(val_f32);
            max_val = max_val.max(val_f32);
        }
    }
    // Constant input: widen the range so the scale is never zero.
    if (max_val - min_val).abs() < f32::EPSILON {
        max_val = min_val + 1.0;
    }
    if method == QuantizationMethod::Float16 {
        let mut f16_data = Array2::zeros(shape);
        for (i, &val) in matrix.iter().enumerate() {
            let val_f32: f32 = val.as_();
            f16_data.as_slice_mut().expect("Operation failed")[i] = f16::from_f32(val_f32);
        }
        let params = QuantizationParams {
            bits: 16,
            // Identity scale/zero point: float storage needs no decoding.
            scale: 1.0,
            zero_point: 0,
            min_val,
            max_val,
            method,
            data_type: QuantizedDataType::Float16,
            channel_scales: None,
            channel_zero_points: None,
        };
        return (QuantizedMatrix::new_f16(f16_data, shape), params);
    }
    if method == QuantizationMethod::BFloat16 {
        let mut bf16_data = Array2::zeros(shape);
        for (i, &val) in matrix.iter().enumerate() {
            let val_f32: f32 = val.as_();
            bf16_data.as_slice_mut().expect("Operation failed")[i] = bf16::from_f32(val_f32);
        }
        let params = QuantizationParams {
            bits: 16,
            scale: 1.0,
            zero_point: 0,
            min_val,
            max_val,
            method,
            data_type: QuantizedDataType::BFloat16,
            channel_scales: None,
            channel_zero_points: None,
        };
        return (QuantizedMatrix::new_bf16(bf16_data, shape), params);
    }
    // Integer storage type for the remaining methods.
    let data_type = match method {
        QuantizationMethod::Int4 => QuantizedDataType::Int4,
        QuantizationMethod::UInt4 => QuantizedDataType::UInt4,
        _ => QuantizedDataType::Int8,
    };
    // Int4/UInt4 always use 4 bits regardless of the `bits` argument.
    let effective_bits = match method {
        QuantizationMethod::Int4 | QuantizationMethod::UInt4 => 4,
        _ => bits,
    };
    let (scale, zero_point) = match method {
        QuantizationMethod::Uniform => {
            // Full-range step size; codes are offsets from min_val.
            let scale = (max_val - min_val) / ((1 << effective_bits) - 1) as f32;
            let zero_point = 0;
            (scale, zero_point)
        }
        QuantizationMethod::Symmetric => {
            // Symmetric around zero: scale from the largest magnitude.
            let abs_max = max_val.abs().max(min_val.abs());
            let scale = abs_max / ((1 << (effective_bits - 1)) - 1) as f32;
            let zero_point = 0;
            (scale, zero_point)
        }
        QuantizationMethod::Affine => {
            let scale = (max_val - min_val) / ((1 << effective_bits) - 1) as f32;
            // Zero point chosen so min_val maps to code 0.
            let zero_point = (-min_val / scale).round() as i32;
            (scale, zero_point)
        }
        QuantizationMethod::PowerOfTwo => {
            // Round the ideal scale up to the next power of two.
            let range = max_val - min_val;
            let ideal_scale = range / ((1 << effective_bits) - 1) as f32;
            let exponent = ideal_scale.log2().ceil();
            let scale = 2.0_f32.powf(exponent);
            let zero_point = 0;
            (scale, zero_point)
        }
        QuantizationMethod::Int4 => {
            // Signed 4-bit symmetric: codes in [-8, 7], scale from |max| / 7.
            let abs_max = max_val.abs().max(min_val.abs());
            let scale = abs_max / 7.0;
            let zero_point = 0;
            (scale, zero_point)
        }
        QuantizationMethod::UInt4 => {
            // Unsigned 4-bit affine: codes in [0, 15].
            let scale = (max_val - min_val) / 15.0;
            let zero_point = (-min_val / scale).round() as i32;
            (scale, zero_point)
        }
        // Float16/BFloat16 returned above; per-channel methods belong to
        // quantize_matrix_per_channel.
        _ => unreachable!(),
    };
    let params = QuantizationParams {
        bits: effective_bits,
        scale,
        zero_point,
        min_val,
        max_val,
        method,
        data_type,
        channel_scales: None,
        channel_zero_points: None,
    };
    match method {
        QuantizationMethod::Int4 => {
            // Pack two signed 4-bit codes per byte in flat row-major order,
            // high nibble first.
            let num_elements = matrix.len();
            let mut packed_data = Array2::zeros((shape.0, shape.1.div_ceil(2)));
            for i in 0..num_elements {
                let val_f32: f32 = matrix.as_slice().expect("Operation failed")[i].as_();
                let q_val = ((val_f32 / scale).round() as i8).clamp(-8, 7);
                let byte_idx = i / 2;
                if i % 2 == 0 {
                    packed_data.as_slice_mut().expect("Operation failed")[byte_idx] = q_val << 4;
                } else {
                    packed_data.as_slice_mut().expect("Operation failed")[byte_idx] |= q_val & 0x0F;
                }
            }
            // No-op reshape: packed_data already has this shape; kept to make
            // the packed layout explicit.
            let packedshape = (shape.0, shape.1.div_ceil(2));
            let packed_reshaped = packed_data
                .into_shape_with_order(packedshape)
                .expect("Operation failed");
            (
                QuantizedMatrix::new_i8(packed_reshaped, shape, QuantizedDataType::Int4),
                params,
            )
        }
        QuantizationMethod::UInt4 => {
            // Same packing as Int4, but codes are unsigned offsets from
            // min_val, clamped to [0, 15].
            let num_elements = matrix.len();
            let mut packed_data = Array2::zeros((shape.0, shape.1.div_ceil(2)));
            for i in 0..num_elements {
                let val_f32: f32 = matrix.as_slice().expect("Operation failed")[i].as_();
                let ival = ((val_f32 - min_val) / scale).round() as i32;
                let q_val = (ival.clamp(0, 15) & 0x0F) as i8;
                let byte_idx = i / 2;
                if i % 2 == 0 {
                    packed_data.as_slice_mut().expect("Operation failed")[byte_idx] = q_val << 4;
                } else {
                    packed_data.as_slice_mut().expect("Operation failed")[byte_idx] |= q_val & 0x0F;
                }
            }
            let packedshape = (shape.0, shape.1.div_ceil(2));
            let packed_reshaped = packed_data
                .into_shape_with_order(packedshape)
                .expect("Operation failed");
            (
                QuantizedMatrix::new_i8(packed_reshaped, shape, QuantizedDataType::UInt4),
                params,
            )
        }
        _ => {
            // One i8 code per element.
            let quantized_data = match method {
                QuantizationMethod::Uniform => {
                    let mut quantized = Array2::zeros(shape);
                    for (i, &val) in matrix.iter().enumerate() {
                        let val_f32: f32 = val.as_();
                        // NOTE(review): codes span 0..=2^bits-1 here, but the
                        // `as i8` cast saturates at 127, so for bits = 8 the
                        // top half of the range collapses — confirm intended.
                        let q_val = ((val_f32 - min_val) / scale).round() as i8;
                        quantized.as_slice_mut().expect("Operation failed")[i] = q_val;
                    }
                    quantized
                }
                QuantizationMethod::Symmetric => {
                    let mut quantized = Array2::zeros(shape);
                    for (i, &val) in matrix.iter().enumerate() {
                        let val_f32: f32 = val.as_();
                        let q_val = (val_f32 / scale).round() as i8;
                        quantized.as_slice_mut().expect("Operation failed")[i] = q_val;
                    }
                    quantized
                }
                QuantizationMethod::Affine => {
                    let mut quantized = Array2::zeros(shape);
                    for (i, &val) in matrix.iter().enumerate() {
                        let val_f32: f32 = val.as_();
                        // NOTE(review): same `as i8` saturation concern as the
                        // Uniform arm above.
                        let q_val = ((val_f32 / scale) + zero_point as f32).round() as i8;
                        quantized.as_slice_mut().expect("Operation failed")[i] = q_val;
                    }
                    quantized
                }
                QuantizationMethod::PowerOfTwo => {
                    let mut quantized = Array2::zeros(shape);
                    for (i, &val) in matrix.iter().enumerate() {
                        let val_f32: f32 = val.as_();
                        let q_val = ((val_f32 - min_val) / scale).round() as i8;
                        quantized.as_slice_mut().expect("Operation failed")[i] = q_val;
                    }
                    quantized
                }
                _ => unreachable!(),
            };
            (
                QuantizedMatrix::new_i8(quantized_data, shape, QuantizedDataType::Int8),
                params,
            )
        }
    }
}
/// Quantize a matrix with an independent scale (and zero point) per column.
///
/// Only `PerChannelSymmetric` and `PerChannelAffine` are accepted; any other
/// method triggers the assertion. The returned params carry the per-column
/// scale/zero-point vectors plus averaged global values.
///
/// # Panics
/// Asserts on an unsupported `method`.
#[allow(dead_code)]
pub fn quantize_matrix_per_channel<F>(
    matrix: &ArrayView2<F>,
    bits: u8,
    method: QuantizationMethod,
) -> (QuantizedMatrix, QuantizationParams)
where
    F: Float + Debug + AsPrimitive<f32> + FromPrimitive,
    f32: AsPrimitive<F>,
{
    assert!(
        method == QuantizationMethod::PerChannelSymmetric
            || method == QuantizationMethod::PerChannelAffine,
        "quantize_matrix_per_channel requires PerChannelSymmetric or PerChannelAffine method, got {method:?}"
    );
    let shape = (matrix.nrows(), matrix.ncols());
    // Channels are columns.
    let num_channels = shape.1;
    let data_type = QuantizedDataType::Int8.clone();
    // Per-column calibration: finite min/max of each column.
    let mut channel_min_vals = vec![F::infinity().as_(); num_channels];
    let mut channel_max_vals = vec![F::neg_infinity().as_(); num_channels];
    for col in 0..num_channels {
        for row in 0..shape.0 {
            let val_f32: f32 = matrix[[row, col]].as_();
            if val_f32.is_finite() {
                channel_min_vals[col] = channel_min_vals[col].min(val_f32);
                channel_max_vals[col] = channel_max_vals[col].max(val_f32);
            }
        }
        // Constant column: widen to avoid a zero scale.
        if (channel_max_vals[col] - channel_min_vals[col]).abs() < f32::EPSILON {
            channel_max_vals[col] = channel_min_vals[col] + 1.0;
        }
    }
    // Global min/max recorded in params for reference.
    let min_val = channel_min_vals
        .iter()
        .fold(F::infinity().as_(), |acc, &val| acc.min(val));
    let max_val = channel_max_vals
        .iter()
        .fold(F::neg_infinity().as_(), |acc, &val| acc.max(val));
    let mut channel_scales = vec![0.0; num_channels];
    let mut channel_zero_points = vec![0; num_channels];
    match method {
        QuantizationMethod::PerChannelSymmetric => {
            for col in 0..num_channels {
                let abs_max = channel_max_vals[col].abs().max(channel_min_vals[col].abs());
                channel_scales[col] = abs_max / ((1 << (bits - 1)) - 1) as f32;
                channel_zero_points[col] = 0;
            }
        }
        QuantizationMethod::PerChannelAffine => {
            for col in 0..num_channels {
                channel_scales[col] =
                    (channel_max_vals[col] - channel_min_vals[col]) / ((1 << bits) - 1) as f32;
                channel_zero_points[col] =
                    (-channel_min_vals[col] / channel_scales[col]).round() as i32;
            }
        }
        _ => unreachable!(),
    }
    // Averaged global scale/zero point; dequantize_matrix uses the
    // per-channel vectors, not these averages.
    let scale = channel_scales.iter().sum::<f32>() / num_channels as f32;
    let zero_point = if method == QuantizationMethod::PerChannelAffine {
        (channel_zero_points.iter().sum::<i32>() as f32 / num_channels as f32).round() as i32
    } else {
        0
    };
    let params = QuantizationParams {
        bits,
        scale,
        zero_point,
        min_val,
        max_val,
        method,
        data_type: data_type.clone(),
        channel_scales: Some(channel_scales.clone()),
        channel_zero_points: Some(channel_zero_points.clone()),
    };
    let mut quantized_data = Array2::zeros(shape);
    for col in 0..num_channels {
        let scale = channel_scales[col];
        let zero_point = channel_zero_points[col];
        for row in 0..shape.0 {
            let val_f32: f32 = matrix[[row, col]].as_();
            let q_val = match method {
                QuantizationMethod::PerChannelSymmetric => {
                    (val_f32 / scale)
                        .round()
                        .clamp(-(1 << (bits - 1)) as f32, ((1 << (bits - 1)) - 1) as f32)
                        as i8
                }
                QuantizationMethod::PerChannelAffine => {
                    // NOTE(review): for bits = 8 the clamp allows 0..=255 but
                    // the `as i8` cast saturates at 127 — confirm intended.
                    ((val_f32 / scale) + zero_point as f32)
                        .round()
                        .clamp(0.0, ((1 << bits) - 1) as f32) as i8
                }
                _ => unreachable!(),
            };
            quantized_data[[row, col]] = q_val;
        }
    }
    (
        QuantizedMatrix::new_i8(quantized_data, shape, data_type.clone()),
        params,
    )
}
/// Reconstruct an `f32` matrix from quantized storage and its parameters.
///
/// Inverse formulas, per method:
/// * Float16/BFloat16: element-wise conversion.
/// * Int4: `q * scale`; UInt4: `min_val + q * scale`.
/// * Per-channel Int8: per-column scale/zero point from `params`.
/// * Int8 Uniform/PowerOfTwo: `min_val + q * scale`; Symmetric: `q * scale`;
///   Affine: `scale * (q - zero_point)`.
///
/// # Panics
/// Panics if `params.method` is inconsistent with the stored data type
/// (`unreachable!` arms) or if per-channel vectors are missing.
#[allow(dead_code)]
pub fn dequantize_matrix(quantized: &QuantizedMatrix, params: &QuantizationParams) -> Array2<f32> {
    let shape = quantized.shape();
    let mut dequantized = Array2::zeros(shape);
    match &quantized.data {
        QuantizedData2D::Float16(data) => {
            for (i, &val) in data.iter().enumerate() {
                dequantized.as_slice_mut().expect("Operation failed")[i] = val.to_f32();
            }
        }
        QuantizedData2D::BFloat16(data) => {
            for (i, &val) in data.iter().enumerate() {
                dequantized.as_slice_mut().expect("Operation failed")[i] = val.to_f32();
            }
        }
        QuantizedData2D::Int8(data) => {
            match quantized.data_type {
                QuantizedDataType::Int4 | QuantizedDataType::UInt4 => {
                    // Packed nibbles: read codes through get_i8.
                    let num_elements = shape.0 * shape.1;
                    for i in 0..num_elements {
                        let row = i / shape.1;
                        let col = i % shape.1;
                        let q_val = quantized.get_i8(row, col);
                        let val = match params.method {
                            QuantizationMethod::Int4 => q_val as f32 * params.scale,
                            QuantizationMethod::UInt4 => {
                                params.min_val + (q_val as f32 * params.scale)
                            }
                            _ => unreachable!(),
                        };
                        dequantized[[row, col]] = val;
                    }
                }
                QuantizedDataType::Int8
                    if params.method == QuantizationMethod::PerChannelSymmetric
                        || params.method == QuantizationMethod::PerChannelAffine =>
                {
                    // Per-channel path: each column has its own parameters.
                    let channel_scales = params
                        .channel_scales
                        .as_ref()
                        .expect("Per-channel quantization requires channel_scales");
                    let channel_zero_points = params
                        .channel_zero_points
                        .as_ref()
                        .expect("Per-channel quantization requires channel_zero_points");
                    let num_channels = shape.1;
                    for row in 0..shape.0 {
                        for col in 0..num_channels {
                            let q_val = data[[row, col]];
                            let scale = channel_scales[col];
                            let zero_point = channel_zero_points[col];
                            let val = match params.method {
                                QuantizationMethod::PerChannelSymmetric => {
                                    q_val as f32 * scale
                                }
                                QuantizationMethod::PerChannelAffine => {
                                    scale * (q_val as f32 - zero_point as f32)
                                }
                                _ => unreachable!(),
                            };
                            dequantized[[row, col]] = val;
                        }
                    }
                }
                QuantizedDataType::Int8 => {
                    match params.method {
                        QuantizationMethod::Uniform => {
                            for (i, &q_val) in data.iter().enumerate() {
                                let val = params.min_val + (q_val as f32 * params.scale);
                                dequantized.as_slice_mut().expect("Operation failed")[i] = val;
                            }
                        }
                        QuantizationMethod::Symmetric => {
                            for (i, &q_val) in data.iter().enumerate() {
                                let val = q_val as f32 * params.scale;
                                dequantized.as_slice_mut().expect("Operation failed")[i] = val;
                            }
                        }
                        QuantizationMethod::Affine => {
                            for (i, &q_val) in data.iter().enumerate() {
                                let val = params.scale * (q_val as f32 - params.zero_point as f32);
                                dequantized.as_slice_mut().expect("Operation failed")[i] = val;
                            }
                        }
                        QuantizationMethod::PowerOfTwo => {
                            for (i, &q_val) in data.iter().enumerate() {
                                let val = params.min_val + (q_val as f32 * params.scale);
                                dequantized.as_slice_mut().expect("Operation failed")[i] = val;
                            }
                        }
                        _ => unreachable!(),
                    }
                }
                // data_type Float16/BFloat16 cannot occur with Int8 storage.
                _ => unreachable!(),
            }
        }
    }
    dequantized
}
/// Quantize a floating-point vector with the given bit width and `method`,
/// returning the quantized vector plus the parameters needed to dequantize.
///
/// Mirrors `quantize_matrix`: Float16/BFloat16 convert element-wise,
/// Int4/UInt4 pack two 4-bit codes per `i8` byte, and the remaining methods
/// produce one `i8` code per element using `bits`.
#[allow(dead_code)]
pub fn quantize_vector<F>(
    vector: &ArrayView1<F>,
    bits: u8,
    method: QuantizationMethod,
) -> (QuantizedVector, QuantizationParams)
where
    F: Float + Debug + AsPrimitive<f32> + FromPrimitive,
    f32: AsPrimitive<F>,
{
    let length = vector.len();
    // Calibration pass: finite min/max of the input.
    let mut min_val = F::infinity().as_();
    let mut max_val = F::neg_infinity().as_();
    for &val in vector.iter() {
        let val_f32: f32 = val.as_();
        if val_f32.is_finite() {
            min_val = min_val.min(val_f32);
            max_val = max_val.max(val_f32);
        }
    }
    // Constant input: widen the range so the scale is never zero.
    if (max_val - min_val).abs() < f32::EPSILON {
        max_val = min_val + 1.0;
    }
    if method == QuantizationMethod::Float16 {
        let mut f16_data = Array1::zeros(length);
        for (i, &val) in vector.iter().enumerate() {
            let val_f32: f32 = val.as_();
            f16_data[i] = f16::from_f32(val_f32);
        }
        let params = QuantizationParams {
            bits: 16,
            // Identity scale/zero point: float storage needs no decoding.
            scale: 1.0,
            zero_point: 0,
            min_val,
            max_val,
            method,
            data_type: QuantizedDataType::Float16,
            channel_scales: None,
            channel_zero_points: None,
        };
        return (QuantizedVector::new_f16(f16_data, length), params);
    }
    if method == QuantizationMethod::BFloat16 {
        let mut bf16_data = Array1::zeros(length);
        for (i, &val) in vector.iter().enumerate() {
            let val_f32: f32 = val.as_();
            bf16_data[i] = bf16::from_f32(val_f32);
        }
        let params = QuantizationParams {
            bits: 16,
            scale: 1.0,
            zero_point: 0,
            min_val,
            max_val,
            method,
            data_type: QuantizedDataType::BFloat16,
            channel_scales: None,
            channel_zero_points: None,
        };
        return (QuantizedVector::new_bf16(bf16_data, length), params);
    }
    // Integer storage type for the remaining methods.
    let data_type = match method {
        QuantizationMethod::Int4 => QuantizedDataType::Int4,
        QuantizationMethod::UInt4 => QuantizedDataType::UInt4,
        _ => QuantizedDataType::Int8,
    };
    // Int4/UInt4 always use 4 bits regardless of the `bits` argument.
    let effective_bits = match method {
        QuantizationMethod::Int4 | QuantizationMethod::UInt4 => 4,
        _ => bits,
    };
    let (scale, zero_point) = match method {
        QuantizationMethod::Uniform => {
            // Full-range step size; codes are offsets from min_val.
            let scale = (max_val - min_val) / ((1 << effective_bits) - 1) as f32;
            let zero_point = 0;
            (scale, zero_point)
        }
        QuantizationMethod::Symmetric => {
            // Symmetric around zero: scale from the largest magnitude.
            let abs_max = max_val.abs().max(min_val.abs());
            let scale = abs_max / ((1 << (effective_bits - 1)) - 1) as f32;
            let zero_point = 0;
            (scale, zero_point)
        }
        QuantizationMethod::Affine => {
            let scale = (max_val - min_val) / ((1 << effective_bits) - 1) as f32;
            // Zero point chosen so min_val maps to code 0.
            let zero_point = (-min_val / scale).round() as i32;
            (scale, zero_point)
        }
        QuantizationMethod::PowerOfTwo => {
            // Round the ideal scale up to the next power of two.
            let range = max_val - min_val;
            let ideal_scale = range / ((1 << effective_bits) - 1) as f32;
            let exponent = ideal_scale.log2().ceil();
            let scale = 2.0_f32.powf(exponent);
            let zero_point = 0;
            (scale, zero_point)
        }
        QuantizationMethod::Int4 => {
            // Signed 4-bit symmetric: codes in [-8, 7], scale from |max| / 7.
            let abs_max = max_val.abs().max(min_val.abs());
            let scale = abs_max / 7.0;
            let zero_point = 0;
            (scale, zero_point)
        }
        QuantizationMethod::UInt4 => {
            // Unsigned 4-bit affine: codes in [0, 15].
            let scale = (max_val - min_val) / 15.0;
            let zero_point = (-min_val / scale).round() as i32;
            (scale, zero_point)
        }
        // Float16/BFloat16 returned above; per-channel methods are
        // matrix-only.
        _ => unreachable!(),
    };
    let params = QuantizationParams {
        bits: effective_bits,
        scale,
        zero_point,
        min_val,
        max_val,
        method,
        data_type,
        channel_scales: None,
        channel_zero_points: None,
    };
    match method {
        QuantizationMethod::Int4 => {
            // Pack two signed 4-bit codes per byte, high nibble first.
            let packedsize = length.div_ceil(2);
            let mut packed_data = Array1::zeros(packedsize);
            for i in 0..length {
                let val_f32: f32 = vector[i].as_();
                let q_val = ((val_f32 / scale).round() as i8).clamp(-8, 7);
                let byte_idx = i / 2;
                if i % 2 == 0 {
                    packed_data[byte_idx] = q_val << 4;
                } else {
                    packed_data[byte_idx] |= q_val & 0x0F;
                }
            }
            (
                QuantizedVector::new_i8(packed_data, length, QuantizedDataType::Int4),
                params,
            )
        }
        QuantizationMethod::UInt4 => {
            // Same packing, but codes are unsigned offsets from min_val,
            // clamped to [0, 15].
            let packedsize = length.div_ceil(2);
            let mut packed_data = Array1::zeros(packedsize);
            for i in 0..length {
                let val_f32: f32 = vector[i].as_();
                let ival = ((val_f32 - min_val) / scale).round() as i32;
                let q_val = (ival.clamp(0, 15) & 0x0F) as i8;
                let byte_idx = i / 2;
                if i % 2 == 0 {
                    packed_data[byte_idx] = q_val << 4;
                } else {
                    packed_data[byte_idx] |= q_val & 0x0F;
                }
            }
            (
                QuantizedVector::new_i8(packed_data, length, QuantizedDataType::UInt4),
                params,
            )
        }
        _ => {
            // One i8 code per element.
            let quantized_data = match method {
                QuantizationMethod::Uniform => {
                    let mut quantized = Array1::zeros(length);
                    for (i, &val) in vector.iter().enumerate() {
                        let val_f32: f32 = val.as_();
                        // NOTE(review): codes span 0..=2^bits-1 here, but the
                        // `as i8` cast saturates at 127, so for bits = 8 the
                        // top half of the range collapses — confirm intended.
                        let q_val = ((val_f32 - min_val) / scale).round() as i8;
                        quantized[i] = q_val;
                    }
                    quantized
                }
                QuantizationMethod::Symmetric => {
                    let mut quantized = Array1::zeros(length);
                    for (i, &val) in vector.iter().enumerate() {
                        let val_f32: f32 = val.as_();
                        let q_val = (val_f32 / scale).round() as i8;
                        quantized[i] = q_val;
                    }
                    quantized
                }
                QuantizationMethod::Affine => {
                    let mut quantized = Array1::zeros(length);
                    for (i, &val) in vector.iter().enumerate() {
                        let val_f32: f32 = val.as_();
                        // NOTE(review): same `as i8` saturation concern as the
                        // Uniform arm above.
                        let q_val = ((val_f32 / scale) + zero_point as f32).round() as i8;
                        quantized[i] = q_val;
                    }
                    quantized
                }
                QuantizationMethod::PowerOfTwo => {
                    let mut quantized = Array1::zeros(length);
                    for (i, &val) in vector.iter().enumerate() {
                        let val_f32: f32 = val.as_();
                        let q_val = ((val_f32 - min_val) / scale).round() as i8;
                        quantized[i] = q_val;
                    }
                    quantized
                }
                _ => unreachable!(),
            };
            (
                QuantizedVector::new_i8(quantized_data, length, QuantizedDataType::Int8),
                params,
            )
        }
    }
}
/// Reconstruct an `f32` vector from quantized storage and its parameters.
///
/// Inverse formulas mirror `dequantize_matrix`:
/// * Float16/BFloat16: element-wise conversion.
/// * Int4: `q * scale`; UInt4: `min_val + q * scale`.
/// * Int8 Uniform/PowerOfTwo: `min_val + q * scale`; Symmetric: `q * scale`;
///   Affine: `scale * (q - zero_point)`.
///
/// # Panics
/// Panics if `params.method` is inconsistent with the stored data type
/// (`unreachable!` arms).
#[allow(dead_code)]
pub fn dequantize_vector(quantized: &QuantizedVector, params: &QuantizationParams) -> Array1<f32> {
    let length = quantized.len();
    let mut dequantized = Array1::zeros(length);
    match &quantized.data {
        QuantizedData1D::Float16(data) => {
            for (i, &val) in data.iter().enumerate() {
                dequantized[i] = val.to_f32();
            }
        }
        QuantizedData1D::BFloat16(data) => {
            for (i, &val) in data.iter().enumerate() {
                dequantized[i] = val.to_f32();
            }
        }
        QuantizedData1D::Int8(data) => {
            match quantized.data_type {
                QuantizedDataType::Int4 | QuantizedDataType::UInt4 => {
                    // Packed nibbles: read codes through get_i8.
                    for i in 0..length {
                        let q_val = quantized.get_i8(i);
                        let val = match params.method {
                            QuantizationMethod::Int4 => q_val as f32 * params.scale,
                            QuantizationMethod::UInt4 => {
                                params.min_val + (q_val as f32 * params.scale)
                            }
                            _ => unreachable!(),
                        };
                        dequantized[i] = val;
                    }
                }
                QuantizedDataType::Int8 => {
                    match params.method {
                        QuantizationMethod::Uniform => {
                            for (i, &q_val) in data.iter().enumerate() {
                                let val = params.min_val + (q_val as f32 * params.scale);
                                dequantized[i] = val;
                            }
                        }
                        QuantizationMethod::Symmetric => {
                            for (i, &q_val) in data.iter().enumerate() {
                                let val = q_val as f32 * params.scale;
                                dequantized[i] = val;
                            }
                        }
                        QuantizationMethod::Affine => {
                            for (i, &q_val) in data.iter().enumerate() {
                                let val = params.scale * (q_val as f32 - params.zero_point as f32);
                                dequantized[i] = val;
                            }
                        }
                        QuantizationMethod::PowerOfTwo => {
                            for (i, &q_val) in data.iter().enumerate() {
                                let val = params.min_val + (q_val as f32 * params.scale);
                                dequantized[i] = val;
                            }
                        }
                        _ => unreachable!(),
                    }
                }
                // data_type Float16/BFloat16 cannot occur with Int8 storage.
                _ => unreachable!(),
            }
        }
    }
    dequantized
}
/// Multiply two quantized matrices, producing a dequantized `f32` result.
///
/// Paths, in order:
/// 1. Float16/BFloat16 operands: multiply element values directly in `f32`.
/// 2. Per-channel operands: dequantize both, then multiply in `f32`.
/// 3. Integer operands: accumulate i8 codes in i32, apply zero-point
///    corrections, and scale once per output element.
///
/// NOTE(review): the integer path models values as `scale * (q - zp)`; the
/// `min_val` offset of Uniform/PowerOfTwo codes is not modelled here (same
/// as the original implementation) — confirm those methods are not expected
/// on this path.
///
/// # Errors
/// `LinalgError::DimensionError` when `a.ncols() != b.nrows()`.
#[allow(dead_code)]
pub fn quantized_matmul(
    a: &QuantizedMatrix,
    a_params: &QuantizationParams,
    b: &QuantizedMatrix,
    b_params: &QuantizationParams,
) -> LinalgResult<Array2<f32>> {
    if a.ncols() != b.nrows() {
        return Err(LinalgError::DimensionError(format!(
            "Cannot multiply matrices with shapes {:?} and {:?}",
            a.shape(),
            b.shape()
        )));
    }
    let (m, k) = a.shape();
    let (_, n) = b.shape();
    let mut result = Array2::zeros((m, n));
    // Float-backed storage: multiply directly in f32.
    if matches!(
        a.data_type,
        QuantizedDataType::Float16 | QuantizedDataType::BFloat16
    ) || matches!(
        b.data_type,
        QuantizedDataType::Float16 | QuantizedDataType::BFloat16
    ) {
        for i in 0..m {
            for j in 0..n {
                let mut sum = 0.0_f32;
                for l in 0..k {
                    let a_val = a.get_f32(i, l);
                    let b_val = b.get_f32(l, j);
                    sum += a_val * b_val;
                }
                result[[i, j]] = sum;
            }
        }
        return Ok(result);
    }
    // Per-channel scales cannot be factored out of an integer accumulation,
    // so dequantize first and multiply in f32.
    let a_per_channel = a_params.method == QuantizationMethod::PerChannelSymmetric
        || a_params.method == QuantizationMethod::PerChannelAffine;
    let b_per_channel = b_params.method == QuantizationMethod::PerChannelSymmetric
        || b_params.method == QuantizationMethod::PerChannelAffine;
    if a_per_channel || b_per_channel {
        let a_dequant = dequantize_matrix(a, a_params);
        let b_dequant = dequantize_matrix(b, b_params);
        for i in 0..m {
            for j in 0..n {
                let mut sum = 0.0_f32;
                for l in 0..k {
                    sum += a_dequant[[i, l]] * b_dequant[[l, j]];
                }
                result[[i, j]] = sum;
            }
        }
        return Ok(result);
    }
    // Integer path. With val = scale * (q - zp):
    //   sum_l (qa - za)(qb - zb)
    //     = sum(qa*qb) - za*sum(qb) - zb*sum(qa) + k*za*zb
    // Fix: the correction used to be applied only when BOTH operands were
    // Affine/UInt4, so a mixed affine/symmetric pair silently dropped the
    // affine operand's zero-point term. Each term is now applied whenever
    // the corresponding zero point is non-zero (symmetric methods have
    // zp == 0, so their results are unchanged).
    let a_zp = a_params.zero_point;
    let b_zp = b_params.zero_point;
    for i in 0..m {
        for j in 0..n {
            let mut sum = 0i32;
            for l in 0..k {
                let a_val = a.get_i8(i, l) as i32;
                let b_val = b.get_i8(l, j) as i32;
                sum += a_val * b_val;
            }
            if a_zp != 0 || b_zp != 0 {
                let b_col_sum: i32 = (0..k).map(|l| b.get_i8(l, j) as i32).sum();
                let a_row_sum: i32 = (0..k).map(|l| a.get_i8(i, l) as i32).sum();
                sum = sum - a_zp * b_col_sum - b_zp * a_row_sum + k as i32 * a_zp * b_zp;
            }
            result[[i, j]] = sum as f32 * a_params.scale * b_params.scale;
        }
    }
    Ok(result)
}
/// Multiply a quantized matrix by a quantized vector, producing a
/// dequantized `f32` result.
///
/// Same path structure as `quantized_matmul`: float storage multiplies in
/// `f32`, per-channel matrices are dequantized first, and integer operands
/// use an i32 accumulator with per-operand zero-point corrections.
///
/// # Errors
/// `LinalgError::DimensionError` when `a.ncols() != x.len()`.
#[allow(dead_code)]
pub fn quantized_matvec(
    a: &QuantizedMatrix,
    a_params: &QuantizationParams,
    x: &QuantizedVector,
    x_params: &QuantizationParams,
) -> LinalgResult<Array1<f32>> {
    if a.ncols() != x.len() {
        return Err(LinalgError::DimensionError(format!(
            "Cannot multiply matrix with shape {:?} and vector with length {}",
            a.shape(),
            x.len()
        )));
    }
    let (m, n) = a.shape();
    let mut result = Array1::zeros(m);
    // Float-backed storage: multiply directly in f32.
    if matches!(
        a.data_type,
        QuantizedDataType::Float16 | QuantizedDataType::BFloat16
    ) || matches!(
        x.data_type,
        QuantizedDataType::Float16 | QuantizedDataType::BFloat16
    ) {
        for i in 0..m {
            let mut sum = 0.0_f32;
            for j in 0..n {
                let a_val = a.get_f32(i, j);
                let x_val = x.get_f32(j);
                sum += a_val * x_val;
            }
            result[i] = sum;
        }
        return Ok(result);
    }
    // Per-channel matrix: dequantize both operands and multiply in f32.
    let a_per_channel = a_params.method == QuantizationMethod::PerChannelSymmetric
        || a_params.method == QuantizationMethod::PerChannelAffine;
    if a_per_channel {
        let a_dequant = dequantize_matrix(a, a_params);
        let x_dequant = dequantize_vector(x, x_params);
        for i in 0..m {
            let mut sum = 0.0_f32;
            for j in 0..n {
                sum += a_dequant[[i, j]] * x_dequant[j];
            }
            result[i] = sum;
        }
        return Ok(result);
    }
    // Integer path with zero-point correction (see quantized_matmul for the
    // derivation). Fix: corrections were previously applied only when BOTH
    // operands were Affine/UInt4, dropping the zero-point term for mixed
    // affine/symmetric inputs; each term is now applied whenever the
    // corresponding zero point is non-zero.
    let a_zp = a_params.zero_point;
    let x_zp = x_params.zero_point;
    // The sum of x's codes is invariant across rows; compute it once.
    let x_code_sum: i32 = if a_zp != 0 {
        (0..n).map(|j| x.get_i8(j) as i32).sum()
    } else {
        0
    };
    for i in 0..m {
        let mut sum = 0i32;
        for j in 0..n {
            let a_val = a.get_i8(i, j) as i32;
            let x_val = x.get_i8(j) as i32;
            sum += a_val * x_val;
        }
        if a_zp != 0 || x_zp != 0 {
            let a_row_sum: i32 = (0..n).map(|j| a.get_i8(i, j) as i32).sum();
            sum = sum - a_zp * x_code_sum - x_zp * a_row_sum + n as i32 * a_zp * x_zp;
        }
        result[i] = sum as f32 * a_params.scale * x_params.scale;
    }
    Ok(result)
}
/// Dot product of two quantized vectors, producing a dequantized `f32`.
///
/// Same path structure as `quantized_matmul`: float storage multiplies in
/// `f32`, per-channel parameters dequantize first, and integer operands use
/// an i32 accumulator with per-operand zero-point corrections.
///
/// # Errors
/// `LinalgError::DimensionError` when the lengths differ.
#[allow(dead_code)]
pub fn quantized_dot(
    a: &QuantizedVector,
    a_params: &QuantizationParams,
    b: &QuantizedVector,
    b_params: &QuantizationParams,
) -> LinalgResult<f32> {
    if a.len() != b.len() {
        return Err(LinalgError::DimensionError(format!(
            "Cannot compute dot product of vectors with lengths {} and {}",
            a.len(),
            b.len()
        )));
    }
    let n = a.len();
    // Float-backed storage: multiply directly in f32.
    if matches!(
        a.data_type,
        QuantizedDataType::Float16 | QuantizedDataType::BFloat16
    ) || matches!(
        b.data_type,
        QuantizedDataType::Float16 | QuantizedDataType::BFloat16
    ) {
        let mut sum = 0.0_f32;
        for i in 0..n {
            let a_val = a.get_f32(i);
            let b_val = b.get_f32(i);
            sum += a_val * b_val;
        }
        return Ok(sum);
    }
    // Per-channel parameters: dequantize both operands first.
    let a_per_channel = a_params.method == QuantizationMethod::PerChannelSymmetric
        || a_params.method == QuantizationMethod::PerChannelAffine;
    let b_per_channel = b_params.method == QuantizationMethod::PerChannelSymmetric
        || b_params.method == QuantizationMethod::PerChannelAffine;
    if a_per_channel || b_per_channel {
        let a_dequant = dequantize_vector(a, a_params);
        let b_dequant = dequantize_vector(b, b_params);
        let mut sum = 0.0_f32;
        for i in 0..n {
            sum += a_dequant[i] * b_dequant[i];
        }
        return Ok(sum);
    }
    // Integer path with zero-point correction (see quantized_matmul for the
    // derivation). Fix: corrections were previously applied only when BOTH
    // operands were Affine/UInt4, dropping the zero-point term for mixed
    // affine/symmetric inputs; each term is now applied whenever the
    // corresponding zero point is non-zero.
    let mut sum = 0i32;
    for i in 0..n {
        let a_val = a.get_i8(i) as i32;
        let b_val = b.get_i8(i) as i32;
        sum += a_val * b_val;
    }
    let a_zp = a_params.zero_point;
    let b_zp = b_params.zero_point;
    if a_zp != 0 || b_zp != 0 {
        let b_code_sum: i32 = (0..n).map(|i| b.get_i8(i) as i32).sum();
        let a_code_sum: i32 = (0..n).map(|i| a.get_i8(i) as i32).sum();
        sum = sum - a_zp * b_code_sum - b_zp * a_code_sum + n as i32 * a_zp * b_zp;
    }
    let result = sum as f32 * a_params.scale * b_params.scale;
    Ok(result)
}
/// Quantize `matrix` with the given method and immediately dequantize it,
/// returning the round-tripped values in the original element type `F`.
/// Useful for simulating quantization error without changing storage.
#[allow(dead_code)]
pub fn fake_quantize<F>(matrix: &ArrayView2<F>, bits: u8, method: QuantizationMethod) -> Array2<F>
where
    F: Float + Debug + AsPrimitive<f32> + FromPrimitive,
    f32: AsPrimitive<F>,
{
    let (quantized, params) = quantize_matrix(matrix, bits, method);
    dequantize_matrix(&quantized, &params).mapv(|v| F::from_f32(v).expect("Operation failed"))
}
/// Quantize `vector` with the given method and immediately dequantize it,
/// returning the round-tripped values in the original element type `F`.
#[allow(dead_code)]
pub fn fake_quantize_vector<F>(
    vector: &ArrayView1<F>,
    bits: u8,
    method: QuantizationMethod,
) -> Array1<F>
where
    F: Float + Debug + AsPrimitive<f32> + FromPrimitive,
    f32: AsPrimitive<F>,
{
    let (quantized, params) = quantize_vector(vector, bits, method);
    dequantize_vector(&quantized, &params).mapv(|v| F::from_f32(v).expect("Operation failed"))
}