use crate::quant::QuantTensor;
use crate::quant::format::QuantFormat;
use crate::quant::traits::{DequantOps, QuantMatmulOps};
use half::f16;
use numr::dtype::DType;
use numr::ops::MatmulOps;
use numr::runtime::cpu::{CpuClient, CpuDevice, CpuRuntime};
use numr::tensor::Tensor;
/// Build the CPU client/device pair shared by every test in this module.
fn setup() -> (CpuClient, CpuDevice) {
    let dev = CpuDevice::new();
    (CpuClient::new(dev.clone()), dev)
}
#[test]
fn test_quant_matmul_q4_0_basic() {
    let (client, device) = setup();

    // One row of 32 activations, all 1.0.
    let act = Tensor::<CpuRuntime>::from_slice(&[1.0f32; 32], &[1, 32], &device);

    // A single Q4_0 block: a little-endian f16 scale followed by 16 bytes of
    // packed nibbles. Scale = 2.0 and every nibble = 9, so each weight
    // dequantizes to (9 - 8) * 2.0 = 2.0.
    let mut block = [0u8; 18];
    block[0..2].copy_from_slice(&f16::from_f32(2.0).to_le_bytes());
    for byte in block[2..18].iter_mut() {
        *byte = 0x99;
    }

    let qt = QuantTensor::<CpuRuntime>::from_bytes(&block, QuantFormat::Q4_0, &[1, 32], &device)
        .unwrap();

    let result = client.quant_matmul(&act, &qt).unwrap();
    assert_eq!(result.shape(), &[1, 1]);

    // Expected dot product: 32 weights * (1.0 * 2.0) = 64.0.
    let data = result.to_vec::<f32>();
    assert!(
        (data[0] - 64.0).abs() < 0.5,
        "expected ~64.0, got {}",
        data[0]
    );
}
#[test]
fn test_quant_matmul_matches_dequant_matmul() {
    let (client, device) = setup();

    // 2x32 activations with small, distinct values.
    let act_data: Vec<f32> = (0..64).map(|i| i as f32 * 0.1).collect();
    let act = Tensor::<CpuRuntime>::from_slice(&act_data, &[2, 32], &device);

    // Three Q8_0 rows, one 34-byte block each: f16 scale 0.5 + 32 int8 weights.
    // Row r stores the repeated byte value (r + 1).
    let mut weight_bytes = Vec::with_capacity(3 * 34);
    for row in 0..3 {
        let mut block = [0u8; 34];
        block[0..2].copy_from_slice(&f16::from_f32(0.5).to_le_bytes());
        for byte in block[2..34].iter_mut() {
            *byte = (row + 1) as u8;
        }
        weight_bytes.extend_from_slice(&block);
    }

    let qt =
        QuantTensor::<CpuRuntime>::from_bytes(&weight_bytes, QuantFormat::Q8_0, &[3, 32], &device)
            .unwrap();

    // The fused quantized matmul must agree with dequantize-then-matmul.
    let result_qm = client.quant_matmul(&act, &qt).unwrap();
    let dense = client.dequantize(&qt, DType::F32).unwrap();
    let dense_t = dense.transpose(0isize, 1isize).unwrap();
    let result_dm = MatmulOps::matmul(&client, &act, &dense_t).unwrap();

    assert_eq!(result_qm.shape(), result_dm.shape());
    let qm_data = result_qm.to_vec::<f32>();
    let dm_data = result_dm.to_vec::<f32>();
    for (i, (&a, &b)) in qm_data.iter().zip(dm_data.iter()).enumerate() {
        assert!(
            (a - b).abs() < 1e-2,
            "mismatch at index {}: quant_matmul={}, dequant+matmul={}",
            i,
            a,
            b
        );
    }
}
#[test]
fn test_quant_matmul_dim_mismatch() {
    let (client, device) = setup();

    // Activations have inner dimension 32, but the quantized weights declare 64.
    let act = Tensor::<CpuRuntime>::from_slice(&[1.0f32; 64], &[2, 32], &device);
    let bytes = vec![0u8; 2 * 34];
    let qt = QuantTensor::<CpuRuntime>::from_bytes(&bytes, QuantFormat::Q8_0, &[1, 64], &device)
        .unwrap();

    // The mismatch must surface as an Err, not a panic or silent wrong answer.
    let result = client.quant_matmul(&act, &qt);
    assert!(result.is_err());
}
#[test]
fn test_quant_matmul_q2k_basic() {
    let (client, device) = setup();

    let act = Tensor::<CpuRuntime>::from_slice(&[1.0f32; 256], &[1, 256], &device);

    // One all-zero 84-byte Q2K super-block: zero scales/mins make every
    // dequantized weight 0, so the product must be ~0 as well.
    let zeros = vec![0u8; 84];
    let qt = QuantTensor::<CpuRuntime>::from_bytes(&zeros, QuantFormat::Q2K, &[1, 256], &device)
        .unwrap();

    let result = client.quant_matmul(&act, &qt).unwrap();
    assert_eq!(result.shape(), &[1, 1]);

    let data = result.to_vec::<f32>();
    assert!(data[0].abs() < 1e-5, "expected ~0.0, got {}", data[0]);
}
#[test]
fn test_quant_matmul_q2k_matches_dequant_matmul() {
    let (client, device) = setup();
    let k = 512;
    let n = 2;
    let m = 2;

    // Small signed activations in a repeating pattern.
    let act_data: Vec<f32> = (0..m * k).map(|i| ((i % 17) as f32 - 8.0) * 0.01).collect();
    let act = Tensor::<CpuRuntime>::from_slice(&act_data, &[m, k], &device);

    // Each weight row is two 84-byte Q2K super-blocks:
    //   bytes  0..16  — packed 4-bit scale/min pairs (scale low nibble, min high)
    //   bytes 16..80  — 2-bit quantized weights
    //   bytes 80..84  — f16 d followed by f16 dmin
    let mut weight_bytes = Vec::new();
    for row in 0..n {
        for blk in 0..2 {
            let mut block = [0u8; 84];
            for (i, slot) in block[..16].iter_mut().enumerate() {
                let s = ((i + row + blk) % 15 + 1) as u8;
                let m_val = ((i + row * 3 + blk) % 10) as u8;
                *slot = s | (m_val << 4);
            }
            for (i, slot) in block[16..80].iter_mut().enumerate() {
                *slot = ((i + row * 7 + blk * 3) % 256) as u8;
            }
            block[80..82].copy_from_slice(&f16::from_f32(0.5).to_le_bytes());
            block[82..84].copy_from_slice(&f16::from_f32(0.1).to_le_bytes());
            weight_bytes.extend_from_slice(&block);
        }
    }

    let qt =
        QuantTensor::<CpuRuntime>::from_bytes(&weight_bytes, QuantFormat::Q2K, &[n, k], &device)
            .unwrap();

    // Fused quantized matmul vs. dequantize-then-matmul reference.
    let result_qm = client.quant_matmul(&act, &qt).unwrap();
    let dense = client.dequantize(&qt, DType::F32).unwrap();
    let dense_t = dense.transpose(0isize, 1isize).unwrap();
    let result_dm = MatmulOps::matmul(&client, &act, &dense_t).unwrap();

    assert_eq!(result_qm.shape(), result_dm.shape());
    let qm_data = result_qm.to_vec::<f32>();
    let dm_data = result_dm.to_vec::<f32>();
    for (i, (&a, &b)) in qm_data.iter().zip(dm_data.iter()).enumerate() {
        // Relative tolerance, floored at an absolute 0.05.
        let tol = 0.05 * b.abs().max(1.0);
        assert!(
            (a - b).abs() < tol,
            "Q2K mismatch at index {}: quant_matmul={}, dequant+matmul={}, tol={}",
            i,
            a,
            b,
            tol
        );
    }
}
#[test]
fn test_quant_matmul_q3k_matches_dequant_matmul() {
    let (client, device) = setup();
    let k = 512;
    let n = 3;
    let m = 2;

    // Small signed activations in a repeating pattern.
    let act_data: Vec<f32> = (0..m * k).map(|i| ((i % 13) as f32 - 6.0) * 0.01).collect();
    let act = Tensor::<CpuRuntime>::from_slice(&act_data, &[m, k], &device);

    // Each weight row is two 110-byte Q3K super-blocks:
    //   bytes   0..32  — high-bit mask
    //   bytes  32..96  — low 2 bits of the quantized weights
    //   bytes  96..108 — packed 6-bit scales
    //   bytes 108..110 — f16 d
    let mut weight_bytes = Vec::new();
    for row in 0..n {
        for blk in 0..2 {
            let mut block = [0u8; 110];
            for (i, slot) in block[..32].iter_mut().enumerate() {
                *slot = ((i * 7 + row * 13 + blk * 5) % 256) as u8;
            }
            for (i, slot) in block[32..96].iter_mut().enumerate() {
                *slot = ((i * 11 + row * 3 + blk * 7) % 256) as u8;
            }
            for (i, slot) in block[96..108].iter_mut().enumerate() {
                *slot = ((i * 5 + row * 9 + blk) % 256) as u8;
            }
            block[108..110].copy_from_slice(&f16::from_f32(0.3).to_le_bytes());
            weight_bytes.extend_from_slice(&block);
        }
    }

    let qt =
        QuantTensor::<CpuRuntime>::from_bytes(&weight_bytes, QuantFormat::Q3K, &[n, k], &device)
            .unwrap();

    // Fused quantized matmul vs. dequantize-then-matmul reference.
    let result_qm = client.quant_matmul(&act, &qt).unwrap();
    let dense = client.dequantize(&qt, DType::F32).unwrap();
    let dense_t = dense.transpose(0isize, 1isize).unwrap();
    let result_dm = MatmulOps::matmul(&client, &act, &dense_t).unwrap();

    assert_eq!(result_qm.shape(), result_dm.shape());
    let qm_data = result_qm.to_vec::<f32>();
    let dm_data = result_dm.to_vec::<f32>();
    for (i, (&a, &b)) in qm_data.iter().zip(dm_data.iter()).enumerate() {
        // Relative tolerance, floored at an absolute 0.05.
        let tol = 0.05 * b.abs().max(1.0);
        assert!(
            (a - b).abs() < tol,
            "Q3K mismatch at index {}: quant_matmul={}, dequant+matmul={}, tol={}",
            i,
            a,
            b,
            tol
        );
    }
}