use crate::error::{Error, Result};
use crate::quant::traits::DequantOps;
use crate::quant::{QuantFormat, QuantTensor};
use numr::dtype::DType;
use numr::ops::TypeConversionOps;
use numr::runtime::cpu::{CpuClient, CpuRuntime};
use numr::tensor::Tensor;
use super::kernels::{dequant, nf4};
impl DequantOps<CpuRuntime> for CpuClient {
fn nf4_dequant(
&self,
nf4_data: &Tensor<CpuRuntime>,
absmax: &Tensor<CpuRuntime>,
blocksize: usize,
) -> Result<Tensor<CpuRuntime>> {
if nf4_data.dtype() != DType::U8 {
return Err(Error::QuantError {
reason: format!("nf4_dequant data must be U8, got {:?}", nf4_data.dtype()),
});
}
let data = unsafe { nf4_data.storage().as_host_slice::<u8>() };
let abs = unsafe { absmax.storage().as_host_slice::<f32>() };
let n = data.len() * 2;
let mut output = vec![0.0f32; n];
nf4::nf4_dequant_f32(data, abs, blocksize, &mut output);
Ok(Tensor::<CpuRuntime>::from_slice(
&output,
&[n],
nf4_data.device(),
))
}
fn nf4_gemm(
&self,
input: &Tensor<CpuRuntime>,
nf4_weight: &Tensor<CpuRuntime>,
absmax: &Tensor<CpuRuntime>,
n_out: usize,
k: usize,
blocksize: usize,
) -> Result<Tensor<CpuRuntime>> {
if input.dtype() != DType::F32 {
return Err(Error::QuantError {
reason: format!("nf4_gemm input must be F32, got {:?}", input.dtype()),
});
}
let in_shape = input.shape();
let m: usize = in_shape.iter().product::<usize>() / k;
let inp = unsafe { input.storage().as_host_slice::<f32>() };
let wt = unsafe { nf4_weight.storage().as_host_slice::<u8>() };
let abs = unsafe { absmax.storage().as_host_slice::<f32>() };
let mut output = vec![0.0f32; m * n_out];
nf4::nf4_gemm_f32(inp, wt, abs, &mut output, m, k, n_out, blocksize);
let mut out_shape = in_shape[..in_shape.len() - 1].to_vec();
out_shape.push(n_out);
Ok(Tensor::<CpuRuntime>::from_slice(
&output,
&out_shape,
input.device(),
))
}
fn dequantize(
&self,
qt: &QuantTensor<CpuRuntime>,
target_dtype: DType,
) -> Result<Tensor<CpuRuntime>> {
if !matches!(
target_dtype,
DType::F32 | DType::F16 | DType::BF16 | DType::F64
) {
return Err(Error::QuantError {
reason: format!("dequantize target must be float, got {:?}", target_dtype),
});
}
let numel = qt.numel();
let block_bytes = unsafe { qt.storage().as_host_slice::<u8>() };
let mut f32_output = vec![0.0f32; numel];
match qt.format() {
QuantFormat::Q4_0 => dequant::dequant_q4_0(block_bytes, &mut f32_output),
QuantFormat::Q4_1 => dequant::dequant_q4_1(block_bytes, &mut f32_output),
QuantFormat::Q5_0 => dequant::dequant_q5_0(block_bytes, &mut f32_output),
QuantFormat::Q5_1 => dequant::dequant_q5_1(block_bytes, &mut f32_output),
QuantFormat::Q8_0 => dequant::dequant_q8_0(block_bytes, &mut f32_output),
QuantFormat::Q8_1 => dequant::dequant_q8_1(block_bytes, &mut f32_output),
QuantFormat::Q2K => dequant::dequant_q2k(block_bytes, &mut f32_output),
QuantFormat::Q3K => dequant::dequant_q3k(block_bytes, &mut f32_output),
QuantFormat::Q4K => dequant::dequant_q4k(block_bytes, &mut f32_output),
QuantFormat::Q5K => dequant::dequant_q5k(block_bytes, &mut f32_output),
QuantFormat::Q6K => dequant::dequant_q6k(block_bytes, &mut f32_output),
QuantFormat::Q8K => dequant::dequant_q8k(block_bytes, &mut f32_output),
QuantFormat::IQ4NL => dequant::dequant_iq4_nl(block_bytes, &mut f32_output),
QuantFormat::IQ4XS => dequant::dequant_iq4_xs(block_bytes, &mut f32_output),
QuantFormat::IQ2XXS => dequant::dequant_iq2_xxs(block_bytes, &mut f32_output),
QuantFormat::IQ2XS => dequant::dequant_iq2_xs(block_bytes, &mut f32_output),
QuantFormat::IQ2S => dequant::dequant_iq2_s(block_bytes, &mut f32_output),
QuantFormat::IQ3XXS => dequant::dequant_iq3_xxs(block_bytes, &mut f32_output),
QuantFormat::IQ3S => dequant::dequant_iq3_s(block_bytes, &mut f32_output),
QuantFormat::IQ1S => dequant::dequant_iq1_s(block_bytes, &mut f32_output),
QuantFormat::IQ1M => dequant::dequant_iq1_m(block_bytes, &mut f32_output),
QuantFormat::TQ1_0 => dequant::dequant_tq1_0(block_bytes, &mut f32_output),
QuantFormat::TQ2_0 => dequant::dequant_tq2_0(block_bytes, &mut f32_output),
}
let f32_tensor = Tensor::<CpuRuntime>::from_slice(&f32_output, qt.shape(), qt.device());
if target_dtype == DType::F32 {
Ok(f32_tensor)
} else {
self.cast(&f32_tensor, target_dtype).map_err(Error::Numr)
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use half::f16;
    use numr::runtime::cpu::CpuDevice;

    /// Creates a CPU client together with the device it was built on.
    fn setup() -> (CpuClient, CpuDevice) {
        let device = CpuDevice::new();
        let client = CpuClient::new(device.clone());
        (client, device)
    }

    /// Builds one 18-byte Q4_0 block: a little-endian f16 scale followed by
    /// 16 payload bytes, all set to `packed`.
    fn q4_0_block(scale: f32, packed: u8) -> [u8; 18] {
        let mut block = [0u8; 18];
        block[0..2].copy_from_slice(&f16::from_f32(scale).to_le_bytes());
        block[2..18].fill(packed);
        block
    }

    #[test]
    fn test_dequant_q4_0_roundtrip() {
        let (client, device) = setup();
        // Every nibble is 9; with scale 2.0 each element is expected to be 2.0.
        let block = q4_0_block(2.0, 0x99);
        let qt = QuantTensor::<CpuRuntime>::from_bytes(&block, QuantFormat::Q4_0, &[32], &device)
            .unwrap();
        let result = client.dequantize(&qt, DType::F32).unwrap();
        assert_eq!(result.shape(), &[32]);
        assert_eq!(result.dtype(), DType::F32);
        let data = result.to_vec::<f32>();
        for &v in &data {
            assert!((v - 2.0).abs() < 0.01, "expected 2.0, got {}", v);
        }
    }

    #[test]
    fn test_dequant_q8_0_roundtrip() {
        let (client, device) = setup();
        // 34-byte Q8_0 block: f16 scale 0.5 followed by 32 quants of 6 -> 3.0.
        let mut block = [0u8; 34];
        block[0..2].copy_from_slice(&f16::from_f32(0.5).to_le_bytes());
        block[2..34].fill(6);
        let qt = QuantTensor::<CpuRuntime>::from_bytes(&block, QuantFormat::Q8_0, &[32], &device)
            .unwrap();
        let result = client.dequantize(&qt, DType::F32).unwrap();
        let data = result.to_vec::<f32>();
        for &v in &data {
            assert!((v - 3.0).abs() < 0.01, "expected 3.0, got {}", v);
        }
    }

    #[test]
    fn test_dequant_q4k_basic() {
        let (client, device) = setup();
        // An all-zero block carries zero scales, so every output must be ~0.
        let block = vec![0u8; 144];
        let qt = QuantTensor::<CpuRuntime>::from_bytes(&block, QuantFormat::Q4K, &[256], &device)
            .unwrap();
        let result = client.dequantize(&qt, DType::F32).unwrap();
        assert_eq!(result.shape(), &[256]);
        let data = result.to_vec::<f32>();
        for &v in &data {
            assert!(v.abs() < 1e-5, "expected 0, got {}", v);
        }
    }

    #[test]
    fn test_dequant_q6k_basic() {
        let (client, device) = setup();
        let block = vec![0u8; 210];
        let qt = QuantTensor::<CpuRuntime>::from_bytes(&block, QuantFormat::Q6K, &[256], &device)
            .unwrap();
        let result = client.dequantize(&qt, DType::F32).unwrap();
        assert_eq!(result.shape(), &[256]);
        // An all-zero block has a zero super-block scale, so the values (not
        // just the shape) must come out as zero.
        let data = result.to_vec::<f32>();
        assert_eq!(data.len(), 256);
        for &v in &data {
            assert!(v.abs() < 1e-5, "expected 0, got {}", v);
        }
    }

    #[test]
    fn test_dequant_to_f64() {
        let (client, device) = setup();
        let block = q4_0_block(1.0, 0x99);
        let qt = QuantTensor::<CpuRuntime>::from_bytes(&block, QuantFormat::Q4_0, &[32], &device)
            .unwrap();
        let result = client.dequantize(&qt, DType::F64).unwrap();
        assert_eq!(result.dtype(), DType::F64);
        assert_eq!(result.shape(), &[32]);
        // The cast must preserve the decoded values, not only the dtype tag.
        let data = result.to_vec::<f64>();
        for &v in &data {
            assert!((v - 1.0).abs() < 0.01, "expected 1.0, got {}", v);
        }
    }

    #[test]
    fn test_dequant_iq1s_basic() {
        let (client, device) = setup();
        let block = vec![0u8; 50];
        let qt = QuantTensor::<CpuRuntime>::from_bytes(&block, QuantFormat::IQ1S, &[256], &device)
            .unwrap();
        let result = client.dequantize(&qt, DType::F32);
        assert!(result.is_ok());
        let data = result.unwrap().to_vec::<f32>();
        assert_eq!(data.len(), 256);
        for &v in &data {
            assert!(v.abs() < 1e-5, "expected 0, got {}", v);
        }
    }

    #[test]
    fn test_dequant_invalid_target() {
        let (client, device) = setup();
        let block = vec![0u8; 18];
        let qt = QuantTensor::<CpuRuntime>::from_bytes(&block, QuantFormat::Q4_0, &[32], &device)
            .unwrap();
        let result = client.dequantize(&qt, DType::I32);
        assert!(result.is_err());
    }

    #[test]
    fn test_nf4_dequant_output_shape() {
        let (client, device) = setup();
        // 32 packed bytes -> 64 outputs, covered by a single 64-element block.
        let data = Tensor::<CpuRuntime>::from_slice(&[0x77u8; 32], &[32], &device);
        let absmax = Tensor::<CpuRuntime>::from_slice(&[1.0f32], &[1], &device);
        let result = client.nf4_dequant(&data, &absmax, 64).unwrap();
        assert_eq!(result.shape(), &[64]);
        assert_eq!(result.dtype(), DType::F32);
        assert_eq!(result.to_vec::<f32>().len(), 64);
    }

    #[test]
    fn test_nf4_dequant_rejects_non_u8_data() {
        let (client, device) = setup();
        let data = Tensor::<CpuRuntime>::from_slice(&[0.0f32; 32], &[32], &device);
        let absmax = Tensor::<CpuRuntime>::from_slice(&[1.0f32], &[1], &device);
        assert!(client.nf4_dequant(&data, &absmax, 64).is_err());
    }

    #[test]
    fn test_nf4_gemm_rejects_non_f32_input() {
        let (client, device) = setup();
        let input = Tensor::<CpuRuntime>::from_slice(&[0u8; 64], &[1, 64], &device);
        let weight = Tensor::<CpuRuntime>::from_slice(&[0x77u8; 32], &[32], &device);
        let absmax = Tensor::<CpuRuntime>::from_slice(&[1.0f32], &[1], &device);
        assert!(client.nf4_gemm(&input, &weight, &absmax, 1, 64, 64).is_err());
    }
}