use super::types::{QuantizableInteger, QuantizedTensor};
use crate::error::{RusTorchError, RusTorchResult};
use ndarray::{ArrayD, Zip};
use num_traits::{Float, FromPrimitive, One, ToPrimitive, Zero};
use std::ops::{Add, Div, Mul, Sub};
/// Arithmetic on quantized tensors.
///
/// Binary operations take a second tensor with the same element type `Q`; scalar operations
/// take a real-valued `f32` operand. Every operation reports failure through
/// `RusTorchResult`.
pub trait QuantizedOps<Q: QuantizableInteger> {
    /// Element-wise addition of `other` to `self`.
    fn qadd(&self, other: &QuantizedTensor<Q>) -> RusTorchResult<QuantizedTensor<Q>>;

    /// Element-wise subtraction of `other` from `self`.
    fn qsub(&self, other: &QuantizedTensor<Q>) -> RusTorchResult<QuantizedTensor<Q>>;

    /// Element-wise multiplication of `self` and `other`.
    fn qmul(&self, other: &QuantizedTensor<Q>) -> RusTorchResult<QuantizedTensor<Q>>;

    /// Matrix product `self · other`.
    fn qmatmul(&self, other: &QuantizedTensor<Q>) -> RusTorchResult<QuantizedTensor<Q>>;

    /// Rectified linear unit: elements whose real value is negative become zero.
    fn qrelu(&self) -> RusTorchResult<QuantizedTensor<Q>>;

    /// Adds the real number `scalar` to every element.
    fn qadd_scalar(&self, scalar: f32) -> RusTorchResult<QuantizedTensor<Q>>;

    /// Multiplies every element by the real number `scalar`.
    fn qmul_scalar(&self, scalar: f32) -> RusTorchResult<QuantizedTensor<Q>>;
}
impl<Q: QuantizableInteger> QuantizedOps<Q> for QuantizedTensor<Q> {
fn qadd(&self, other: &QuantizedTensor<Q>) -> RusTorchResult<QuantizedTensor<Q>> {
if self.shape() != other.shape() {
return Err(RusTorchError::ShapeMismatch {
expected: self.shape().to_vec(),
actual: other.shape().to_vec(),
});
}
let (result_scale, result_zero_point) = if self.is_compatible_with(other) {
(self.scale, self.zero_point)
} else {
compute_output_quantization_params(
(self.scale, self.zero_point),
(other.scale, other.zero_point),
QuantizedOperation::Add,
)?
};
let result_data = if self.is_compatible_with(other) {
Zip::from(&self.data)
.and(&other.data)
.map_collect(|&a, &b| {
let sum = QuantizableInteger::to_i32(&a)
.saturating_add(QuantizableInteger::to_i32(&b))
.saturating_sub(self.zero_point);
Q::from_i32_clamped(sum)
})
} else {
Zip::from(&self.data)
.and(&other.data)
.map_collect(|&a, &b| {
let a_fp =
(QuantizableInteger::to_i32(&a) - self.zero_point) as f32 * self.scale;
let b_fp =
(QuantizableInteger::to_i32(&b) - other.zero_point) as f32 * other.scale;
let sum_fp = a_fp + b_fp;
let quantized = (sum_fp / result_scale).round() as i32 + result_zero_point;
Q::from_i32_clamped(quantized)
})
};
Ok(QuantizedTensor::new(
result_data,
result_scale,
result_zero_point,
self.device.clone(),
))
}
fn qsub(&self, other: &QuantizedTensor<Q>) -> RusTorchResult<QuantizedTensor<Q>> {
if self.shape() != other.shape() {
return Err(RusTorchError::ShapeMismatch {
expected: self.shape().to_vec(),
actual: other.shape().to_vec(),
});
}
let (result_scale, result_zero_point) = if self.is_compatible_with(other) {
(self.scale, self.zero_point)
} else {
compute_output_quantization_params(
(self.scale, self.zero_point),
(other.scale, other.zero_point),
QuantizedOperation::Sub,
)?
};
let result_data = if self.is_compatible_with(other) {
Zip::from(&self.data)
.and(&other.data)
.map_collect(|&a, &b| {
let diff = QuantizableInteger::to_i32(&a)
.saturating_sub(QuantizableInteger::to_i32(&b))
.saturating_add(self.zero_point);
Q::from_i32_clamped(diff)
})
} else {
Zip::from(&self.data)
.and(&other.data)
.map_collect(|&a, &b| {
let a_fp =
(QuantizableInteger::to_i32(&a) - self.zero_point) as f32 * self.scale;
let b_fp =
(QuantizableInteger::to_i32(&b) - other.zero_point) as f32 * other.scale;
let diff_fp = a_fp - b_fp;
let quantized = (diff_fp / result_scale).round() as i32 + result_zero_point;
Q::from_i32_clamped(quantized)
})
};
Ok(QuantizedTensor::new(
result_data,
result_scale,
result_zero_point,
self.device.clone(),
))
}
fn qmul(&self, other: &QuantizedTensor<Q>) -> RusTorchResult<QuantizedTensor<Q>> {
if self.shape() != other.shape() {
return Err(RusTorchError::ShapeMismatch {
expected: self.shape().to_vec(),
actual: other.shape().to_vec(),
});
}
let result_scale = self.scale * other.scale;
let result_zero_point = 0;
let result_data = Zip::from(&self.data)
.and(&other.data)
.map_collect(|&a, &b| {
let a_adjusted = QuantizableInteger::to_i32(&a) - self.zero_point;
let b_adjusted = QuantizableInteger::to_i32(&b) - other.zero_point;
let product = a_adjusted * b_adjusted;
let scaled_product =
(product as f32 / (self.scale * other.scale / result_scale)).round() as i32;
Q::from_i32_clamped(scaled_product)
});
Ok(QuantizedTensor::new(
result_data,
result_scale,
result_zero_point,
self.device.clone(),
))
}
fn qmatmul(&self, other: &QuantizedTensor<Q>) -> RusTorchResult<QuantizedTensor<Q>> {
let self_shape = self.shape();
let other_shape = other.shape();
if self_shape.len() < 2 || other_shape.len() < 2 {
return Err(RusTorchError::TensorOp {
message: "Matrix multiplication requires at least 2D tensors".to_string(),
source: None,
});
}
let self_cols = self_shape[self_shape.len() - 1];
let other_rows = other_shape[other_shape.len() - 2];
if self_cols != other_rows {
return Err(RusTorchError::ShapeMismatch {
expected: vec![self_cols],
actual: vec![other_rows],
});
}
if self_shape.len() != 2 || other_shape.len() != 2 {
return Err(RusTorchError::TensorOp {
message: "Only 2D matrix multiplication currently supported".to_string(),
source: None,
});
}
let m = self_shape[0];
let k = self_shape[1];
let n = other_shape[1];
let result_scale = self.scale * other.scale;
let result_zero_point = 0;
let mut result_data = ArrayD::zeros(vec![m, n]);
for i in 0..m {
for j in 0..n {
let mut sum = 0i64;
for l in 0..k {
let a_val = QuantizableInteger::to_i32(&self.data[[i, l]]) - self.zero_point;
let b_val = QuantizableInteger::to_i32(&other.data[[l, j]]) - other.zero_point;
sum += (a_val as i64) * (b_val as i64);
}
let fp_result = sum as f32 * result_scale;
let quantized = (fp_result / result_scale).round() as i32;
result_data[[i, j]] = Q::from_i32_clamped(quantized);
}
}
Ok(QuantizedTensor::new(
result_data,
result_scale,
result_zero_point,
self.device.clone(),
))
}
fn qrelu(&self) -> RusTorchResult<QuantizedTensor<Q>> {
let zero_quantized = Q::from_i32_clamped(self.zero_point);
let result_data = self.data.mapv(|val| {
if QuantizableInteger::to_i32(&val) > self.zero_point {
val
} else {
zero_quantized
}
});
Ok(QuantizedTensor::new(
result_data,
self.scale,
self.zero_point,
self.device.clone(),
))
}
fn qadd_scalar(&self, scalar: f32) -> RusTorchResult<QuantizedTensor<Q>> {
let scalar_quantized = (scalar / self.scale).round() as i32 + self.zero_point;
let scalar_clamped = Q::from_i32_clamped(scalar_quantized);
let result_data = self.data.mapv(|val| {
let sum = QuantizableInteger::to_i32(&val)
.saturating_add(QuantizableInteger::to_i32(&scalar_clamped))
.saturating_sub(self.zero_point);
Q::from_i32_clamped(sum)
});
Ok(QuantizedTensor::new(
result_data,
self.scale,
self.zero_point,
self.device.clone(),
))
}
fn qmul_scalar(&self, scalar: f32) -> RusTorchResult<QuantizedTensor<Q>> {
let new_scale = self.scale * scalar.abs();
let result_data = self.data.mapv(|val| {
let adjusted = QuantizableInteger::to_i32(&val) - self.zero_point;
let scaled = (adjusted as f32 * scalar).round() as i32;
Q::from_i32_clamped(scaled)
});
Ok(QuantizedTensor::new(
result_data,
new_scale,
0, self.device.clone(),
))
}
}
/// Conversions out of a quantized representation: back to floating point, or into a
/// different set of quantization parameters.
pub trait DequantizeOps<Q: QuantizableInteger> {
    /// Real value of every element as `f32`.
    fn dequantize_f32(&self) -> ArrayD<f32>;

    /// Real value of every element as `f64`.
    fn dequantize_f64(&self) -> ArrayD<f64>;

    /// Re-expresses the tensor with scale `new_scale` and zero point `new_zero_point`,
    /// keeping the element type `Q`.
    fn dequantize_partial(
        &self,
        new_scale: f32,
        new_zero_point: i32,
    ) -> RusTorchResult<QuantizedTensor<Q>>;
}
impl<Q: QuantizableInteger> DequantizeOps<Q> for QuantizedTensor<Q> {
fn dequantize_f32(&self) -> ArrayD<f32> {
self.data.mapv(|q_val| {
(QuantizableInteger::to_i32(&q_val) - self.zero_point) as f32 * self.scale
})
}
fn dequantize_f64(&self) -> ArrayD<f64> {
self.data.mapv(|q_val| {
((QuantizableInteger::to_i32(&q_val) - self.zero_point) as f32 * self.scale) as f64
})
}
fn dequantize_partial(
&self,
new_scale: f32,
new_zero_point: i32,
) -> RusTorchResult<QuantizedTensor<Q>> {
let result_data = self.data.mapv(|q_val| {
let fp_val = (QuantizableInteger::to_i32(&q_val) - self.zero_point) as f32 * self.scale;
let new_q_val = (fp_val / new_scale).round() as i32 + new_zero_point;
Q::from_i32_clamped(new_q_val)
});
Ok(QuantizedTensor::new(
result_data,
new_scale,
new_zero_point,
self.device.clone(),
))
}
}
/// Binary operation kinds for which `compute_output_quantization_params` chooses output
/// quantization parameters.
#[derive(Clone, Copy, Debug)]
enum QuantizedOperation {
    /// Addition.
    Add,
    /// Subtraction.
    Sub,
    /// Multiplication.
    Mul,
    /// Division.
    Div,
}
/// Chooses the output `(scale, zero_point)` for a binary operation whose operands are
/// quantized with `params1` and `params2` (each a `(scale, zero_point)` pair).
///
/// * `Add` / `Sub`: the larger of the two scales.
/// * `Mul`: the product of the scales.
/// * `Div`: `scale1 / scale2`, falling back to `scale1` when `scale2` is zero.
///
/// The operands' zero points do not influence the choice; the output zero point is always 0.
fn compute_output_quantization_params(
    params1: (f32, i32),
    params2: (f32, i32),
    operation: QuantizedOperation,
) -> RusTorchResult<(f32, i32)> {
    // Zero points are intentionally ignored: every output is symmetric (zero point 0).
    let (scale1, _) = params1;
    let (scale2, _) = params2;
    let result_scale = match operation {
        QuantizedOperation::Add | QuantizedOperation::Sub => scale1.max(scale2),
        QuantizedOperation::Mul => scale1 * scale2,
        QuantizedOperation::Div => {
            if scale2 != 0.0 {
                scale1 / scale2
            } else {
                scale1
            }
        }
    };
    Ok((result_scale, 0))
}
pub fn qlinear<Q: QuantizableInteger>(
input: &QuantizedTensor<Q>,
weight: &QuantizedTensor<Q>,
bias: Option<&QuantizedTensor<Q>>,
) -> RusTorchResult<QuantizedTensor<Q>> {
let output = input.qmatmul(weight)?;
if let Some(bias_tensor) = bias {
output.qadd(bias_tensor)
} else {
Ok(output)
}
}
pub fn qconv1d<Q: QuantizableInteger>(
input: &QuantizedTensor<Q>,
weight: &QuantizedTensor<Q>,
bias: Option<&QuantizedTensor<Q>>,
stride: usize,
padding: usize,
) -> RusTorchResult<QuantizedTensor<Q>> {
let input_shape = input.shape();
let weight_shape = weight.shape();
if input_shape.len() != 3 || weight_shape.len() != 3 {
return Err(RusTorchError::TensorOp {
message: "Expected 3D tensors for 1D convolution [batch, channels, length]".to_string(),
source: None,
});
}
let batch_size = input_shape[0];
let in_channels = input_shape[1];
let input_length = input_shape[2];
let out_channels = weight_shape[0];
let kernel_size = weight_shape[2];
if weight_shape[1] != in_channels {
return Err(RusTorchError::ShapeMismatch {
expected: vec![in_channels],
actual: vec![weight_shape[1]],
});
}
let output_length = (input_length + 2 * padding - kernel_size) / stride + 1;
let result_scale = input.scale * weight.scale;
let result_shape = vec![batch_size, out_channels, output_length];
let result_data = ArrayD::zeros(result_shape);
Ok(QuantizedTensor::new(
result_data,
result_scale,
0,
input.device.clone(),
))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::device::Device;
    use ndarray::Array2;

    /// Builds a 2-D `i8` quantized tensor on the default device with zero point 0.
    fn qtensor_2d(shape: (usize, usize), values: Vec<i8>, scale: f32) -> QuantizedTensor<i8> {
        let data = Array2::from_shape_vec(shape, values).unwrap().into_dyn();
        QuantizedTensor::new(data, scale, 0, Device::default())
    }

    #[test]
    fn test_quantized_addition() {
        let lhs = qtensor_2d((2, 2), vec![10, 20, 30, 40], 0.1);
        let rhs = qtensor_2d((2, 2), vec![5, 10, 15, 20], 0.1);
        let sum = lhs.qadd(&rhs).unwrap();
        assert_eq!(sum.shape(), &[2, 2]);
        assert_eq!(sum.scale, 0.1);
        assert_eq!(sum.zero_point, 0);
    }

    #[test]
    fn test_quantized_multiplication() {
        let lhs = qtensor_2d((2, 2), vec![2, 3, 4, 5], 0.1);
        let rhs = qtensor_2d((2, 2), vec![3, 4, 5, 6], 0.2);
        let product = lhs.qmul(&rhs).unwrap();
        assert_eq!(product.shape(), &[2, 2]);
        assert_eq!(product.scale, 0.1 * 0.2);
        assert_eq!(product.zero_point, 0);
    }

    #[test]
    fn test_quantized_relu() {
        let input = qtensor_2d((2, 2), vec![-10, -5, 5, 10], 0.1);
        let activated = input.qrelu().unwrap();
        assert_eq!(activated.shape(), &[2, 2]);
        assert_eq!(activated.scale, 0.1);
        assert_eq!(activated.zero_point, 0);
    }

    #[test]
    fn test_quantized_matmul() {
        let lhs = qtensor_2d((2, 3), vec![1, 2, 3, 4, 5, 6], 0.1);
        let rhs = qtensor_2d((3, 2), vec![7, 8, 9, 10, 11, 12], 0.1);
        let product = lhs.qmatmul(&rhs).unwrap();
        assert_eq!(product.shape(), &[2, 2]);
        assert_eq!(product.scale, 0.1 * 0.1);
    }

    #[test]
    fn test_scalar_operations() {
        let base = qtensor_2d((2, 2), vec![10, 20, 30, 40], 0.1);
        let shifted = base.qadd_scalar(5.0).unwrap();
        assert_eq!(shifted.scale, 0.1);
        let scaled = base.qmul_scalar(2.0).unwrap();
        assert_eq!(scaled.scale, 0.1 * 2.0);
    }

    #[test]
    fn test_dequantization() {
        let quantized = qtensor_2d((2, 2), vec![10, 20, 30, 40], 0.1);
        let as_f32 = quantized.dequantize_f32();
        let as_f64 = quantized.dequantize_f64();
        assert_eq!(as_f32.shape(), &[2, 2]);
        assert_eq!(as_f64.shape(), &[2, 2]);
        assert_eq!(as_f32[[0, 0]], 1.0);
        assert_eq!(as_f32[[0, 1]], 2.0);
    }

    #[test]
    fn test_qlinear() {
        let input = qtensor_2d((1, 3), vec![1, 2, 3], 0.1);
        let weight = qtensor_2d((3, 2), vec![1, 2, 3, 4, 5, 6], 0.1);
        let output = qlinear(&input, &weight, None).unwrap();
        assert_eq!(output.shape(), &[1, 2]);
    }
}