use crate::{BackendResult, Device};
use torsh_core::error::TorshError;
#[cfg(not(feature = "std"))]
use alloc::vec::Vec;
/// Element types supported for quantized tensor storage.
///
/// Sub-byte types (`Int4`/`UInt4`/`Binary`) are stored packed, several
/// values per byte — see `is_packed` / `packing_factor`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum QuantizedDType {
    /// Signed 8-bit integer, range [-128, 127].
    Int8,
    /// Unsigned 8-bit integer, range [0, 255].
    UInt8,
    /// Signed 16-bit integer, range [-32768, 32767].
    Int16,
    /// Unsigned 16-bit integer, range [0, 65535].
    UInt16,
    /// Signed 4-bit integer, range [-8, 7]; two values per byte.
    Int4,
    /// Unsigned 4-bit integer, range [0, 15]; two values per byte.
    UInt4,
    /// Single-bit values, range [0, 1]; eight values per byte.
    Binary,
    /// Mixed precision: one bit-width entry per group.
    /// NOTE(review): the grouping granularity (per-layer? per-block?) is not
    /// visible in this file — confirm against whatever fills this vector in.
    Mixed(Vec<u8>),
}
impl QuantizedDType {
pub fn bits(&self) -> u8 {
match self {
QuantizedDType::Int8 | QuantizedDType::UInt8 => 8,
QuantizedDType::Int16 | QuantizedDType::UInt16 => 16,
QuantizedDType::Int4 | QuantizedDType::UInt4 => 4,
QuantizedDType::Binary => 1,
QuantizedDType::Mixed(bits) => bits.iter().max().copied().unwrap_or(8),
}
}
pub fn is_signed(&self) -> bool {
matches!(
self,
QuantizedDType::Int8 | QuantizedDType::Int16 | QuantizedDType::Int4
)
}
pub fn value_range(&self) -> (i64, i64) {
match self {
QuantizedDType::Int8 => (-128, 127),
QuantizedDType::UInt8 => (0, 255),
QuantizedDType::Int16 => (-32768, 32767),
QuantizedDType::UInt16 => (0, 65535),
QuantizedDType::Int4 => (-8, 7),
QuantizedDType::UInt4 => (0, 15),
QuantizedDType::Binary => (0, 1),
QuantizedDType::Mixed(_) => (0, 255), }
}
pub fn bytes_per_element(&self) -> usize {
match self {
QuantizedDType::Int8 | QuantizedDType::UInt8 => 1,
QuantizedDType::Int16 | QuantizedDType::UInt16 => 2,
QuantizedDType::Int4 | QuantizedDType::UInt4 => 1, QuantizedDType::Binary => 1, QuantizedDType::Mixed(_) => 1, }
}
pub fn is_packed(&self) -> bool {
matches!(
self,
QuantizedDType::Int4 | QuantizedDType::UInt4 | QuantizedDType::Binary
)
}
pub fn packing_factor(&self) -> usize {
match self {
QuantizedDType::Int4 | QuantizedDType::UInt4 => 2,
QuantizedDType::Binary => 8,
_ => 1,
}
}
}
/// How real values are mapped onto the quantized grid.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QuantizationScheme {
    /// Affine mapping with a scale and zero-point.
    Linear,
    /// Logarithmic-domain mapping.
    Logarithmic,
    /// Scale-only mapping; zero-point is fixed at 0.
    Symmetric,
    /// Affine mapping with an explicit zero-point offset.
    Asymmetric,
    /// Separate parameters per fixed-size block (requires `block_size`).
    BlockWise,
    /// Separate scale/zero-point per channel.
    ChannelWise,
}
impl QuantizationScheme {
pub fn supports_zero_point(&self) -> bool {
matches!(
self,
QuantizationScheme::Asymmetric | QuantizationScheme::Linear
)
}
pub fn is_per_channel(&self) -> bool {
matches!(self, QuantizationScheme::ChannelWise)
}
pub fn is_block_wise(&self) -> bool {
matches!(self, QuantizationScheme::BlockWise)
}
}
/// Complete parameter set describing how a tensor was quantized.
#[derive(Debug, Clone)]
pub struct QuantizationParams {
    // Element type of the quantized values.
    pub dtype: QuantizedDType,
    // Mapping scheme (symmetric, asymmetric, channel-wise, ...).
    pub scheme: QuantizationScheme,
    // One scale for global schemes; one per channel for channel-wise.
    pub scale: Vec<f32>,
    // Zero-point offsets, parallel to `scale`.
    pub zero_point: Vec<i32>,
    // Block length; required (Some) only for block-wise quantization.
    pub block_size: Option<usize>,
    // Observed float minimum, recorded by `from_statistics`.
    pub min_val: Option<f32>,
    // Observed float maximum, recorded by `from_statistics`.
    pub max_val: Option<f32>,
}
impl Default for QuantizationParams {
    /// Plain linear uint8 quantization: unit scale, zero offset, and no
    /// block/channel structure or recorded statistics.
    fn default() -> Self {
        Self {
            // Identity mapping until calibrated.
            scale: vec![1.0],
            zero_point: vec![0],
            dtype: QuantizedDType::UInt8,
            scheme: QuantizationScheme::Linear,
            // No block structure and no calibration statistics yet.
            block_size: None,
            min_val: None,
            max_val: None,
        }
    }
}
impl QuantizationParams {
    /// Symmetric signed 8-bit quantization (zero-point fixed at 0).
    pub fn int8_symmetric() -> Self {
        Self {
            dtype: QuantizedDType::Int8,
            scheme: QuantizationScheme::Symmetric,
            ..Self::default()
        }
    }

    /// Asymmetric unsigned 8-bit quantization with the zero-point at
    /// mid-range (128).
    pub fn uint8_asymmetric() -> Self {
        Self {
            dtype: QuantizedDType::UInt8,
            scheme: QuantizationScheme::Asymmetric,
            zero_point: vec![128],
            ..Self::default()
        }
    }

    /// Symmetric signed 4-bit quantization.
    pub fn int4_symmetric() -> Self {
        Self {
            dtype: QuantizedDType::Int4,
            scheme: QuantizationScheme::Symmetric,
            ..Self::default()
        }
    }

    /// Binary (1-bit) quantization.
    pub fn binary() -> Self {
        Self {
            dtype: QuantizedDType::Binary,
            scheme: QuantizationScheme::Symmetric,
            ..Self::default()
        }
    }

    /// Channel-wise quantization with one (scale, zero_point) pair per
    /// channel, initialized to the identity mapping.
    pub fn channel_wise(num_channels: usize, dtype: QuantizedDType) -> Self {
        Self {
            dtype,
            scheme: QuantizationScheme::ChannelWise,
            scale: vec![1.0; num_channels],
            zero_point: vec![0; num_channels],
            ..Self::default()
        }
    }

    /// Block-wise quantization over blocks of `block_size` elements.
    pub fn block_wise(block_size: usize, dtype: QuantizedDType) -> Self {
        Self {
            dtype,
            scheme: QuantizationScheme::BlockWise,
            block_size: Some(block_size),
            ..Self::default()
        }
    }

    /// Derive `scale[0]` / `zero_point[0]` from observed float min/max and
    /// record the statistics in `min_val` / `max_val`.
    ///
    /// Only the first slot is updated, so channel-wise parameters must be
    /// calibrated per channel by the caller.
    ///
    /// # Errors
    /// Returns a dimension error when `scale` or `zero_point` is empty
    /// (e.g. built via `channel_wise(0, ..)`) instead of panicking on the
    /// index, as the previous version did.
    pub fn from_statistics(&mut self, min_val: f32, max_val: f32) -> BackendResult<()> {
        if self.scale.is_empty() || self.zero_point.is_empty() {
            return Err(TorshError::dimension_error(
                "from_statistics requires non-empty scale and zero_point vectors",
                "from_statistics",
            )
            .into());
        }
        self.min_val = Some(min_val);
        self.max_val = Some(max_val);
        let (qmin, qmax) = self.dtype.value_range();
        let qmin = qmin as f32;
        let qmax = qmax as f32;
        match self.scheme {
            QuantizationScheme::Symmetric => {
                // Symmetric: zero-point pinned to 0, scale spreads
                // ±max(|min|, |max|) over the full quantized range.
                let max_range = max_val.abs().max(min_val.abs());
                self.scale[0] = if max_range == 0.0 {
                    // Degenerate all-zero tensor: keep the identity scale.
                    1.0
                } else {
                    (2.0 * max_range) / (qmax - qmin)
                };
                self.zero_point[0] = 0;
            }
            // Asymmetric and every remaining scheme share the affine
            // mapping q = x / scale + zero_point (the two arms were
            // byte-identical before; merged here).
            _ => {
                if max_val == min_val {
                    self.scale[0] = 1.0;
                    self.zero_point[0] = 0;
                } else {
                    self.scale[0] = (max_val - min_val) / (qmax - qmin);
                    self.zero_point[0] = (qmin - min_val / self.scale[0]).round() as i32;
                }
            }
        }
        Ok(())
    }

    /// Check internal consistency of the parameter set.
    ///
    /// # Errors
    /// - Channel-wise: `scale` / `zero_point` must be non-empty and of
    ///   equal length.
    /// - Any other scheme: exactly one scale and one zero-point.
    /// - Block-wise additionally requires `block_size`.
    /// - Every scale must be strictly positive and finite.
    pub fn validate(&self) -> BackendResult<()> {
        if self.scheme.is_per_channel() {
            if self.scale.is_empty() || self.zero_point.is_empty() {
                return Err(TorshError::dimension_error(
                    "Channel-wise quantization requires non-empty scale and zero_point vectors",
                    "validate",
                )
                .into());
            }
            if self.scale.len() != self.zero_point.len() {
                return Err(TorshError::dimension_error(
                    "Scale and zero_point vectors must have the same length for channel-wise quantization",
                    "validate",
                )
                .into());
            }
        } else if self.scale.len() != 1 || self.zero_point.len() != 1 {
            return Err(TorshError::dimension_error(
                "Non-channel-wise quantization requires exactly one scale and zero_point value",
                "validate",
            )
            .into());
        }
        if self.scheme.is_block_wise() && self.block_size.is_none() {
            return Err(TorshError::dimension_error(
                "Block-wise quantization requires a block_size",
                "validate",
            )
            .into());
        }
        for &scale in &self.scale {
            if scale <= 0.0 || !scale.is_finite() {
                return Err(TorshError::dimension_error(
                    "Scale values must be positive and finite",
                    "validate",
                )
                .into());
            }
        }
        Ok(())
    }

    /// Bytes of metadata this parameter set adds on top of the payload.
    ///
    /// Uses `core::mem` (not `std::mem`) so this also compiles under the
    /// `no_std` configuration this file supports via the cfg-gated `alloc`
    /// import.
    pub fn memory_overhead(&self) -> usize {
        let scale_size = self.scale.len() * core::mem::size_of::<f32>();
        let zero_point_size = self.zero_point.len() * core::mem::size_of::<i32>();
        scale_size + zero_point_size + core::mem::size_of::<Self>()
    }

    /// Single-scale asymmetric int8 parameter set from explicit values.
    pub fn new(scale: f32, zero_point: i32) -> Self {
        Self {
            dtype: QuantizedDType::Int8,
            scheme: QuantizationScheme::Asymmetric,
            scale: vec![scale],
            zero_point: vec![zero_point],
            ..Self::default()
        }
    }
}
/// A tensor stored in quantized form: raw (possibly packed) bytes plus the
/// parameters needed to interpret them.
#[derive(Debug, Clone)]
pub struct QuantizedTensor {
    // Raw payload; for packed dtypes several values share each byte.
    pub data: Vec<u8>,
    // Logical dimensions of the tensor.
    pub shape: Vec<usize>,
    // How the values in `data` were quantized.
    pub params: QuantizationParams,
    // Device the tensor is resident on.
    pub device: Device,
}
impl QuantizedTensor {
    /// Payload bytes needed to hold `total_elements` values under `params`.
    ///
    /// Packed dtypes (4-bit, binary) ceil-divide the element count by the
    /// packing factor; unpacked dtypes use whole bytes per element. Shared
    /// by `new` and `validate_data_size` (the logic was duplicated before)
    /// so the two can never disagree.
    fn expected_data_size(params: &QuantizationParams, total_elements: usize) -> usize {
        if params.dtype.is_packed() {
            let packing_factor = params.dtype.packing_factor();
            // Ceiling division: a partially filled trailing byte still counts.
            (total_elements + packing_factor - 1) / packing_factor
        } else {
            total_elements * params.dtype.bytes_per_element()
        }
    }

    /// Allocate a zero-filled quantized tensor of the given shape.
    ///
    /// # Errors
    /// Propagates `params.validate()` failures.
    pub fn new(
        shape: Vec<usize>,
        params: QuantizationParams,
        device: Device,
    ) -> BackendResult<Self> {
        params.validate()?;
        let total_elements: usize = shape.iter().product();
        let data_size = Self::expected_data_size(&params, total_elements);
        Ok(Self {
            data: vec![0; data_size],
            shape,
            params,
            device,
        })
    }

    /// Wrap existing raw bytes as a quantized tensor.
    ///
    /// # Errors
    /// Fails when `params` is invalid or `data.len()` does not match the
    /// size implied by `shape` and `params`.
    pub fn from_data(
        data: Vec<u8>,
        shape: Vec<usize>,
        params: QuantizationParams,
        device: Device,
    ) -> BackendResult<Self> {
        params.validate()?;
        let tensor = Self {
            data,
            shape,
            params,
            device,
        };
        tensor.validate_data_size()?;
        Ok(tensor)
    }

    /// Logical element count (product of the shape; 1 for an empty shape).
    pub fn num_elements(&self) -> usize {
        self.shape.iter().product()
    }

    /// Total bytes: payload plus quantization-parameter metadata.
    pub fn memory_usage(&self) -> usize {
        self.data.len() + self.params.memory_overhead()
    }

    /// Size of the raw payload in bytes.
    pub fn data_size(&self) -> usize {
        self.data.len()
    }

    /// Ratio of equivalent f32 storage (4 bytes/element) to the quantized
    /// payload. NOTE(review): a zero-byte payload yields NaN/inf — same as
    /// the previous behavior.
    pub fn compression_ratio(&self) -> f32 {
        let fp32_size = self.num_elements() * 4;
        let quantized_size = self.data_size();
        fp32_size as f32 / quantized_size as f32
    }

    /// Verify that `data.len()` matches the size implied by shape + params.
    fn validate_data_size(&self) -> BackendResult<()> {
        let expected_size = Self::expected_data_size(&self.params, self.num_elements());
        if self.data.len() != expected_size {
            return Err(TorshError::dimension_error(
                &format!(
                    "Data size mismatch: expected {} bytes, got {}",
                    expected_size,
                    self.data.len()
                ),
                "validate_memory_layout",
            )
            .into());
        }
        Ok(())
    }

    /// Tensor dimensions.
    pub fn shape(&self) -> &[usize] {
        &self.shape
    }

    /// Quantization parameters used to interpret `data`.
    pub fn params(&self) -> &QuantizationParams {
        &self.params
    }

    /// Device the tensor lives on.
    pub fn device(&self) -> &Device {
        &self.device
    }

    /// True when resident on the CPU.
    pub fn is_cpu(&self) -> bool {
        matches!(self.device.device_type, torsh_core::device::DeviceType::Cpu)
    }

    /// True when resident on a CUDA or Metal device.
    pub fn is_gpu(&self) -> bool {
        matches!(
            self.device.device_type,
            torsh_core::device::DeviceType::Cuda(_) | torsh_core::device::DeviceType::Metal(_)
        )
    }

    /// Raw (possibly packed) payload bytes.
    pub fn data(&self) -> &[u8] {
        &self.data
    }

    /// Mutable access to the payload bytes.
    pub fn data_mut(&mut self) -> &mut [u8] {
        &mut self.data
    }
}
/// Quantize f32 values to int8 using `params.scale[0]` / `params.zero_point[0]`:
/// `q = clamp(round(x / scale) + zero_point, -128, 127)`.
///
/// # Panics
/// Panics if `params.scale` or `params.zero_point` is empty; call
/// `params.validate()` first.
pub fn quantize_to_int8(data: &[f32], params: &QuantizationParams) -> Result<Vec<i8>, TorshError> {
    let scale = params.scale[0];
    // Keep the zero-point in i32. The previous `zero_point[0] as i8` cast
    // silently wrapped values outside [-128, 127] before widening back to
    // i32 — e.g. the uint8-style zero-point 128 became -128, corrupting
    // every quantized value.
    let zero_point = params.zero_point[0];
    let quantized = data
        .iter()
        .map(|&x| {
            let scaled = (x / scale).round() as i32 + zero_point;
            scaled.clamp(-128, 127) as i8
        })
        .collect();
    Ok(quantized)
}
/// Dequantize int8 values back to f32: `x = (q - zero_point) * scale`,
/// using `params.scale[0]` / `params.zero_point[0]`.
///
/// # Panics
/// Panics if `params.scale` or `params.zero_point` is empty; call
/// `params.validate()` first.
pub fn dequantize_from_int8(
    data: &[i8],
    params: &QuantizationParams,
) -> Result<Vec<f32>, TorshError> {
    let scale = params.scale[0];
    // Widen to i32 before subtracting. The previous i8 arithmetic both
    // wrapped zero-points outside [-128, 127] (`zero_point[0] as i8`) and
    // could overflow i8 — a panic in debug builds — on e.g. 127 - (-128).
    let zero_point = params.zero_point[0];
    let dequantized = data
        .iter()
        .map(|&x| (x as i32 - zero_point) as f32 * scale)
        .collect();
    Ok(dequantized)
}
#[cfg(test)]
mod tests {
    use super::*;

    // bits(): per-variant widths; Mixed reports the widest entry.
    #[test]
    fn test_quantized_dtype_bits() {
        assert_eq!(QuantizedDType::Int8.bits(), 8);
        assert_eq!(QuantizedDType::UInt8.bits(), 8);
        assert_eq!(QuantizedDType::Int16.bits(), 16);
        assert_eq!(QuantizedDType::Int4.bits(), 4);
        assert_eq!(QuantizedDType::Binary.bits(), 1);
        assert_eq!(QuantizedDType::Mixed(vec![4, 8, 16]).bits(), 16);
    }

    // is_signed(): true only for the Int* variants.
    #[test]
    fn test_quantized_dtype_signed() {
        assert!(QuantizedDType::Int8.is_signed());
        assert!(!QuantizedDType::UInt8.is_signed());
        assert!(QuantizedDType::Int4.is_signed());
        assert!(!QuantizedDType::UInt4.is_signed());
    }

    // value_range(): inclusive (min, max) per dtype.
    #[test]
    fn test_quantized_dtype_value_range() {
        assert_eq!(QuantizedDType::Int8.value_range(), (-128, 127));
        assert_eq!(QuantizedDType::UInt8.value_range(), (0, 255));
        assert_eq!(QuantizedDType::Int4.value_range(), (-8, 7));
        assert_eq!(QuantizedDType::Binary.value_range(), (0, 1));
    }

    // Scheme predicates: zero-point support, per-channel, block-wise.
    #[test]
    fn test_quantization_scheme_properties() {
        assert!(QuantizationScheme::Asymmetric.supports_zero_point());
        assert!(!QuantizationScheme::Symmetric.supports_zero_point());
        assert!(QuantizationScheme::ChannelWise.is_per_channel());
        assert!(QuantizationScheme::BlockWise.is_block_wise());
    }

    // Preset constructors set the expected dtype/scheme/zero-point.
    #[test]
    fn test_quantization_params_creation() {
        let params = QuantizationParams::int8_symmetric();
        assert_eq!(params.dtype, QuantizedDType::Int8);
        assert_eq!(params.scheme, QuantizationScheme::Symmetric);
        assert_eq!(params.zero_point[0], 0);
        let params = QuantizationParams::uint8_asymmetric();
        assert_eq!(params.dtype, QuantizedDType::UInt8);
        assert_eq!(params.scheme, QuantizationScheme::Asymmetric);
        assert_eq!(params.zero_point[0], 128);
    }

    // validate(): rejects non-positive and non-finite scales.
    #[test]
    fn test_quantization_params_validation() {
        let mut params = QuantizationParams::default();
        assert!(params.validate().is_ok());
        params.scale[0] = 0.0;
        assert!(params.validate().is_err());
        params.scale[0] = f32::NAN;
        assert!(params.validate().is_err());
    }

    // Tensor construction on CPU: element count and shape round-trip.
    #[test]
    fn test_quantized_tensor_creation() {
        let params = QuantizationParams::uint8_asymmetric();
        let tensor = QuantizedTensor::new(
            vec![2, 3, 4],
            params,
            Device::cpu().expect("Quantized Tensor should succeed"),
        );
        assert!(tensor.is_ok());
        let tensor = tensor.expect("operation should succeed");
        assert_eq!(tensor.num_elements(), 24);
        assert_eq!(tensor.shape(), &[2, 3, 4]);
        assert!(tensor.is_cpu());
    }

    // uint8 stores 1 byte/element vs 4 for f32 => exactly 4x compression.
    #[test]
    fn test_compression_ratio() {
        let params = QuantizationParams::uint8_asymmetric();
        let tensor = QuantizedTensor::new(
            vec![10, 10],
            params,
            Device::cpu().expect("Quantized Tensor should succeed"),
        )
        .expect("operation should succeed");
        assert_eq!(tensor.compression_ratio(), 4.0);
    }

    // Packing: 4-bit packs 2/byte, binary packs 8/byte, 8-bit is unpacked.
    #[test]
    fn test_packed_types() {
        assert!(QuantizedDType::Int4.is_packed());
        assert!(QuantizedDType::Binary.is_packed());
        assert!(!QuantizedDType::Int8.is_packed());
        assert_eq!(QuantizedDType::Int4.packing_factor(), 2);
        assert_eq!(QuantizedDType::Binary.packing_factor(), 8);
        assert_eq!(QuantizedDType::Int8.packing_factor(), 1);
    }
}