use super::super::operations::QuantizedOps;
use super::super::types::{QuantizableInteger, QuantizedTensor};
use crate::error::{RusTorchError, RusTorchResult};
use crate::tensor::device::Device;
use ndarray::ArrayD;
pub struct CpuQuantizedOps;
impl CpuQuantizedOps {
pub fn qadd_simd_i8(
a: &[i8],
b: &[i8],
output: &mut [i8],
scale_a: f32,
scale_b: f32,
zero_point_a: i32,
zero_point_b: i32,
output_scale: f32,
output_zero_point: i32,
) -> RusTorchResult<()> {
if a.len() != b.len() || a.len() != output.len() {
return Err(RusTorchError::TensorOp {
message: "Array lengths must match for quantized addition".to_string(),
source: None,
});
}
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx2") {
return Self::qadd_avx2_i8(
a,
b,
output,
scale_a,
scale_b,
zero_point_a,
zero_point_b,
output_scale,
output_zero_point,
);
}
}
Self::qadd_scalar_i8(
a,
b,
output,
scale_a,
scale_b,
zero_point_a,
zero_point_b,
output_scale,
output_zero_point,
)
}
fn qadd_scalar_i8(
a: &[i8],
b: &[i8],
output: &mut [i8],
scale_a: f32,
scale_b: f32,
zero_point_a: i32,
zero_point_b: i32,
output_scale: f32,
output_zero_point: i32,
) -> RusTorchResult<()> {
for (((&a_val, &b_val), out_val), _) in
a.iter().zip(b.iter()).zip(output.iter_mut()).zip(0..)
{
let a_fp = (a_val as i32 - zero_point_a) as f32 * scale_a;
let b_fp = (b_val as i32 - zero_point_b) as f32 * scale_b;
let sum_fp = a_fp + b_fp;
let quantized = (sum_fp / output_scale).round() as i32 + output_zero_point;
*out_val = quantized.clamp(-128, 127) as i8;
}
Ok(())
}
#[cfg(target_arch = "x86_64")]
fn qadd_avx2_i8(
a: &[i8],
b: &[i8],
output: &mut [i8],
scale_a: f32,
scale_b: f32,
zero_point_a: i32,
zero_point_b: i32,
output_scale: f32,
output_zero_point: i32,
) -> RusTorchResult<()> {
#[cfg(target_feature = "avx2")]
unsafe {
use std::arch::x86_64::*;
let len = a.len();
let simd_len = len & !7;
let scale_a_vec = _mm256_set1_ps(scale_a);
let scale_b_vec = _mm256_set1_ps(scale_b);
let inv_output_scale = _mm256_set1_ps(1.0 / output_scale);
let zero_point_a_vec = _mm256_set1_ps(zero_point_a as f32);
let zero_point_b_vec = _mm256_set1_ps(zero_point_b as f32);
let output_zero_point_vec = _mm256_set1_ps(output_zero_point as f32);
for i in (0..simd_len).step_by(8) {
let a_i8 = _mm_loadl_epi64(a.as_ptr().add(i) as *const __m128i);
let b_i8 = _mm_loadl_epi64(b.as_ptr().add(i) as *const __m128i);
let a_i32 = _mm256_cvtepi8_epi32(a_i8);
let b_i32 = _mm256_cvtepi8_epi32(b_i8);
let a_f32 = _mm256_cvtepi32_ps(a_i32);
let b_f32 = _mm256_cvtepi32_ps(b_i32);
let a_dequant = _mm256_mul_ps(_mm256_sub_ps(a_f32, zero_point_a_vec), scale_a_vec);
let b_dequant = _mm256_mul_ps(_mm256_sub_ps(b_f32, zero_point_b_vec), scale_b_vec);
let sum_f32 = _mm256_add_ps(a_dequant, b_dequant);
let quantized_f32 = _mm256_add_ps(
_mm256_mul_ps(sum_f32, inv_output_scale),
output_zero_point_vec,
);
let quantized_i32 = _mm256_cvtps_epi32(quantized_f32);
let quantized_i16 = _mm256_packs_epi32(quantized_i32, _mm256_setzero_si256());
let quantized_i8 = _mm_packs_epi16(
_mm256_extracti128_si256(quantized_i16, 0),
_mm256_extracti128_si256(quantized_i16, 1),
);
_mm_storel_epi64(output.as_mut_ptr().add(i) as *mut __m128i, quantized_i8);
}
for i in simd_len..len {
let a_fp = (a[i] as i32 - zero_point_a) as f32 * scale_a;
let b_fp = (b[i] as i32 - zero_point_b) as f32 * scale_b;
let sum_fp = a_fp + b_fp;
let quantized = (sum_fp / output_scale).round() as i32 + output_zero_point;
output[i] = quantized.clamp(-128, 127) as i8;
}
}
#[cfg(not(target_feature = "avx2"))]
{
Self::qadd_scalar_i8(
a,
b,
output,
scale_a,
scale_b,
zero_point_a,
zero_point_b,
output_scale,
output_zero_point,
)?;
}
Ok(())
}
pub fn qmatmul_i8_optimized(
a: &ArrayD<i8>,
b: &ArrayD<i8>,
scale_a: f32,
scale_b: f32,
zero_point_a: i32,
zero_point_b: i32,
) -> RusTorchResult<(ArrayD<i8>, f32, i32)> {
let a_shape = a.shape();
let b_shape = b.shape();
if a_shape.len() != 2 || b_shape.len() != 2 {
return Err(RusTorchError::TensorOp {
message: "Only 2D matrices supported for optimized qmatmul".to_string(),
source: None,
});
}
let m = a_shape[0];
let k = a_shape[1];
let n = b_shape[1];
if b_shape[0] != k {
return Err(RusTorchError::ShapeMismatch {
expected: vec![k],
actual: vec![b_shape[0]],
});
}
let mut output = ArrayD::zeros(vec![m, n]);
let output_scale = scale_a * scale_b;
let output_zero_point = 0;
const BLOCK_SIZE: usize = 64;
for ii in (0..m).step_by(BLOCK_SIZE) {
for jj in (0..n).step_by(BLOCK_SIZE) {
for kk in (0..k).step_by(BLOCK_SIZE) {
let i_end = (ii + BLOCK_SIZE).min(m);
let j_end = (jj + BLOCK_SIZE).min(n);
let k_end = (kk + BLOCK_SIZE).min(k);
for i in ii..i_end {
for j in jj..j_end {
let mut sum = if kk == 0 {
0i32
} else {
let existing = output[[i, j]] as i32;
(existing as f32 / output_scale).round() as i32 - output_zero_point
};
for l in kk..k_end {
let a_val = a[[i, l]] as i32 - zero_point_a;
let b_val = b[[l, j]] as i32 - zero_point_b;
sum += a_val * b_val;
}
if kk + BLOCK_SIZE >= k {
let fp_result = sum as f32 * output_scale;
let quantized_result = (fp_result / output_scale).round() as i32;
output[[i, j]] = quantized_result.clamp(-128, 127) as i8;
} else {
let intermediate = sum + output_zero_point;
output[[i, j]] = intermediate.clamp(-128, 127) as i8;
}
}
}
}
}
}
Ok((output, output_scale, output_zero_point))
}
}
pub fn qmatmul_cpu<Q: QuantizableInteger>(
a: &QuantizedTensor<Q>,
b: &QuantizedTensor<Q>,
) -> RusTorchResult<QuantizedTensor<Q>> {
a.qmatmul(b)
}
pub fn qconv2d_cpu<Q: QuantizableInteger>(
input: &QuantizedTensor<Q>,
weight: &QuantizedTensor<Q>,
bias: Option<&QuantizedTensor<Q>>,
stride: (usize, usize),
padding: (usize, usize),
) -> RusTorchResult<QuantizedTensor<Q>> {
let output_shape = vec![
input.shape()[0], weight.shape()[0], (input.shape()[2] + 2 * padding.0 - weight.shape()[2]) / stride.0 + 1, (input.shape()[3] + 2 * padding.1 - weight.shape()[3]) / stride.1 + 1, ];
let result_scale = input.scale * weight.scale;
let result_data = ArrayD::zeros(output_shape);
Ok(QuantizedTensor::new(
result_data,
result_scale,
0,
input.device.clone(),
))
}
pub mod optimized_ops {
use super::*;
pub fn check_cpu_features() -> CpuFeatures {
CpuFeatures {
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
avx2: cfg!(target_feature = "avx2") || std::arch::is_x86_feature_detected!("avx2"),
#[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
avx2: false,
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
avx512: cfg!(target_feature = "avx512f")
|| std::arch::is_x86_feature_detected!("avx512f"),
#[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
avx512: false,
neon: cfg!(target_feature = "neon"),
}
}
pub fn qlinear_cpu_optimized<Q: QuantizableInteger>(
input: &QuantizedTensor<Q>,
weight: &QuantizedTensor<Q>,
bias: Option<&QuantizedTensor<Q>>,
) -> RusTorchResult<QuantizedTensor<Q>> {
let features = check_cpu_features();
if features.avx2 {
qmatmul_cpu(input, weight)
} else {
qmatmul_cpu(input, weight)
}
}
pub fn batch_qoperations<Q: QuantizableInteger>(
operations: &[QuantizedOperation<Q>],
) -> RusTorchResult<Vec<QuantizedTensor<Q>>> {
let mut results = Vec::with_capacity(operations.len());
for op in operations {
match op {
QuantizedOperation::MatMul(a, b) => {
results.push(qmatmul_cpu(a, b)?);
}
QuantizedOperation::Add(a, b) => {
results.push(a.qadd(b)?);
} }
}
Ok(results)
}
}
#[derive(Debug, Clone)]
pub struct CpuFeatures {
pub avx2: bool,
pub avx512: bool,
pub neon: bool,
}
#[derive(Debug)]
pub enum QuantizedOperation<'a, Q: QuantizableInteger> {
MatMul(&'a QuantizedTensor<Q>, &'a QuantizedTensor<Q>),
Add(&'a QuantizedTensor<Q>, &'a QuantizedTensor<Q>),
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tensor::device::Device;
use ndarray::Array2;
#[test]
fn test_cpu_features() {
let features = optimized_ops::check_cpu_features();
println!("CPU Features: {:?}", features);
}
#[test]
fn test_qadd_simd() {
let a = vec![10i8, 20, 30, 40];
let b = vec![5i8, 10, 15, 20];
let mut output = vec![0i8; 4];
let result = CpuQuantizedOps::qadd_simd_i8(&a, &b, &mut output, 0.1, 0.1, 0, 0, 0.1, 0);
assert!(result.is_ok());
assert!(output.iter().all(|&x| x != 0));
}
#[test]
fn test_qmatmul_i8_optimized() {
let a_data = Array2::from_shape_vec((2, 3), vec![1i8, 2, 3, 4, 5, 6])
.unwrap()
.into_dyn();
let b_data = Array2::from_shape_vec((3, 2), vec![7i8, 8, 9, 10, 11, 12])
.unwrap()
.into_dyn();
let result = CpuQuantizedOps::qmatmul_i8_optimized(&a_data, &b_data, 0.1, 0.1, 0, 0);
assert!(result.is_ok());
let (output, scale, zero_point) = result.unwrap();
assert_eq!(output.shape(), &[2, 2]);
assert_eq!(scale, 0.1 * 0.1);
assert_eq!(zero_point, 0);
}
#[test]
fn test_qconv2d_cpu() {
let input_data = ArrayD::<i8>::zeros(vec![1, 3, 32, 32]);
let weight_data = ArrayD::<i8>::zeros(vec![16, 3, 3, 3]);
let input = QuantizedTensor::new(input_data, 0.1, 0, Device::default());
let weight = QuantizedTensor::new(weight_data, 0.1, 0, Device::default());
let result = qconv2d_cpu(&input, &weight, None, (1, 1), (1, 1));
assert!(result.is_ok());
let output = result.unwrap();
assert_eq!(output.shape()[0], 1); assert_eq!(output.shape()[1], 16); }
}