//! Quantized weight formats (symmetric INT8, block-wise INT4, and a
//! simplified Q4_K super-block format) with NEON-accelerated GEMV kernels
//! for aarch64 and scalar fallbacks for every other target.
#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::*;
#[allow(unused_imports)]
use super::NEON_LANE_WIDTH;
/// Elements per INT4 quantization block.
pub const INT4_BLOCK_SIZE: usize = 32;
/// Elements per Q4_K super-block.
pub const Q4K_SUPER_BLOCK_SIZE: usize = 256;
/// Sub-blocks per Q4_K super-block.
pub const Q4K_SUB_BLOCKS: usize = 8;
/// Elements per Q4_K sub-block (256 / 8 = 32).
pub const Q4K_SUB_BLOCK_SIZE: usize = Q4K_SUPER_BLOCK_SIZE / Q4K_SUB_BLOCKS;
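// The packing below stores two 4-bit codes per byte, so blocks only stay
// byte-aligned when block sizes are even; the Q4_K NEON kernel additionally
// consumes 16 elements per step. A minimal compile-time guard for both:
const _: () = assert!(INT4_BLOCK_SIZE % 2 == 0 && Q4K_SUPER_BLOCK_SIZE % 16 == 0);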
/// Symmetric per-tensor INT8 quantization: `value ≈ code as f32 * scale`.
#[derive(Debug, Clone)]
pub struct QuantizedInt8 {
    pub data: Vec<i8>,
    pub scale: f32,
}
/// Asymmetric block-wise INT4 quantization, two codes per byte:
/// `value ≈ q as f32 * scales[block] + mins[block]`.
#[derive(Debug, Clone)]
pub struct QuantizedInt4 {
    pub data: Vec<u8>,
    pub scales: Vec<f32>,
    pub mins: Vec<f32>,
    pub block_size: usize,
}
/// A simplified Q4_K-style super-block covering 256 elements.
#[derive(Debug, Clone)]
#[repr(C)]
pub struct BlockQ4K {
    /// Super-block value range (max - min) as f16 bits.
    pub d: u16,
    /// Super-block minimum as f16 bits.
    pub dmin: u16,
    /// Packed 6-bit per-sub-block relative scales.
    pub scales: [u8; 12],
    /// 256 4-bit codes, two per byte, low nibble first.
    pub qs: [u8; 128],
}
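// With #[repr(C)] and 2-byte alignment the struct has no padding, so a block
// is exactly 2 + 2 + 12 + 128 = 144 bytes. A compile-time guard mirroring
// `test_q4k_structure` below:
const _: () = assert!(std::mem::size_of::<BlockQ4K>() == 144);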
/// Quantizes to symmetric INT8: `scale = max|x| / 127`, codes in [-127, 127]
/// (-128 is left unused so the range stays symmetric).
pub fn quantize_to_int8(data: &[f32]) -> (Vec<i8>, f32) {
if data.is_empty() {
return (Vec::new(), 1.0);
}
let max_abs = data.iter().fold(0.0f32, |acc, &x| acc.max(x.abs()));
let scale = if max_abs > 0.0 { max_abs / 127.0 } else { 1.0 };
let inv_scale = 1.0 / scale;
let quantized: Vec<i8> = data
.iter()
.map(|&x| {
let q = (x * inv_scale).round();
q.clamp(-127.0, 127.0) as i8
})
.collect();
(quantized, scale)
}
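// Worked example: for data = [0.5, -1.0], max|x| = 1.0, so scale = 1 / 127
// and the codes are [64, -127] (0.5 * 127 = 63.5 rounds away from zero).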
/// Inverse of `quantize_to_int8`: `value = code as f32 * scale`.
pub fn dequantize_int8(data: &[i8], scale: f32) -> Vec<f32> {
data.iter().map(|&x| (x as f32) * scale).collect()
}
/// Quantizes to asymmetric block-wise INT4. Returns `(packed, scales, mins)`
/// with two 4-bit codes per packed byte and one scale/min pair per block.
/// `block_size` should be even so blocks stay byte-aligned for the
/// dequantization and GEMV routines below.
pub fn quantize_to_int4(data: &[f32], block_size: usize) -> (Vec<u8>, Vec<f32>, Vec<f32>) {
if data.is_empty() {
return (Vec::new(), Vec::new(), Vec::new());
}
let num_blocks = (data.len() + block_size - 1) / block_size;
let mut scales = Vec::with_capacity(num_blocks);
let mut mins = Vec::with_capacity(num_blocks);
let mut packed = Vec::with_capacity((data.len() + 1) / 2);
for block_idx in 0..num_blocks {
let start = block_idx * block_size;
let end = (start + block_size).min(data.len());
let block = &data[start..end];
let (min_val, max_val) = block.iter().fold((f32::MAX, f32::MIN), |(min, max), &x| {
(min.min(x), max.max(x))
});
        // 16 levels span [min, max], so the step is (max - min) / 15; guard
        // against a degenerate (constant) block.
        let scale = if (max_val - min_val).abs() > 1e-10 {
            (max_val - min_val) / 15.0
        } else {
            1.0
        };
        scales.push(scale);
        mins.push(min_val);
        let inv_scale = 1.0 / scale;
        // Pack pairs of codes: the even element goes in the low nibble, the
        // odd element in the high nibble; an odd-length tail pads with 0.
        let mut i = 0;
        while i < block.len() {
            let q0 = ((block[i] - min_val) * inv_scale).round().clamp(0.0, 15.0) as u8;
            let q1 = if i + 1 < block.len() {
                ((block[i + 1] - min_val) * inv_scale)
                    .round()
                    .clamp(0.0, 15.0) as u8
            } else {
                0
            };
            packed.push(q0 | (q1 << 4));
            i += 2;
        }
}
(packed, scales, mins)
}
/// Inverse of `quantize_to_int4`; `block_size` and `num_elements` must match
/// the values used during quantization (even `block_size` assumed).
pub fn dequantize_int4(
packed: &[u8],
scales: &[f32],
mins: &[f32],
block_size: usize,
num_elements: usize,
) -> Vec<f32> {
let mut output = Vec::with_capacity(num_elements);
for block_idx in 0..scales.len() {
        // Blocks are byte-aligned because block_size is even.
        let start_byte = (block_idx * block_size) / 2;
let scale = scales[block_idx];
let min = mins[block_idx];
let elements_in_block = if block_idx == scales.len() - 1 {
num_elements - block_idx * block_size
} else {
block_size
};
for i in 0..elements_in_block {
let byte_idx = start_byte + i / 2;
let byte = packed[byte_idx];
let q = if i % 2 == 0 { byte & 0x0F } else { byte >> 4 };
output.push((q as f32) * scale + min);
}
}
output
}
/// Quantizes one 256-element super-block into a simplified Q4_K layout.
/// Unlike full Q4_K, the 4-bit codes are produced against the single global
/// scale `d / 15` and minimum `dmin`; the packed per-sub-block scales are
/// stored for layout compatibility but not consumed by the kernels below.
pub fn quantize_to_q4k(data: &[f32]) -> BlockQ4K {
    debug_assert_eq!(data.len(), Q4K_SUPER_BLOCK_SIZE);
    let (global_min, global_max) = data.iter().fold((f32::MAX, f32::MIN), |(min, max), &x| {
        (min.min(x), max.max(x))
    });
    // d holds the value range and dmin the minimum, both as f16 bits.
    let d = f32_to_f16(global_max - global_min);
    let dmin = f32_to_f16(global_min);
    let mut sub_scales = [0u8; 12];
    // Use the f16-rounded values so quantization matches dequantization.
    let global_scale = f16_to_f32(d);
    let global_min_f = f16_to_f32(dmin);
for sb in 0..Q4K_SUB_BLOCKS {
let start = sb * Q4K_SUB_BLOCK_SIZE;
let end = start + Q4K_SUB_BLOCK_SIZE;
let sub_block = &data[start..end];
let (sb_min, sb_max) = sub_block
.iter()
.fold((f32::MAX, f32::MIN), |(min, max), &x| {
(min.min(x), max.max(x))
});
        // Each sub-block's range, relative to the global range, as a 6-bit
        // value in 0..=63.
        let rel_scale = if global_scale > 1e-10 {
            ((sb_max - sb_min) / global_scale * 63.0)
                .round()
                .clamp(0.0, 63.0) as u8
        } else {
            0
        };
        // Pack the 6-bit value little-endian across byte boundaries. Bits
        // shifted past the top of the first byte are discarded by the u8
        // shift and re-written into the next byte.
        let byte_idx = (sb * 6) / 8;
        let bit_offset = (sb * 6) % 8;
        sub_scales[byte_idx] |= rel_scale << bit_offset;
        if bit_offset > 2 && byte_idx + 1 < 12 {
            sub_scales[byte_idx + 1] |= rel_scale >> (8 - bit_offset);
        }
}
    // Quantize all 256 elements against the global scale: codes 0..=15 map
    // to [dmin, dmin + d] via value = q * (d / 15) + dmin.
    let mut qs = [0u8; 128];
    let scale = if global_scale > 1e-10 {
        global_scale / 15.0
    } else {
        1.0
    };
    let inv_scale = 1.0 / scale;
    for i in 0..Q4K_SUPER_BLOCK_SIZE {
        let q = ((data[i] - global_min_f) * inv_scale)
            .round()
            .clamp(0.0, 15.0) as u8;
        // Even element in the low nibble, odd element in the high nibble.
        if i % 2 == 0 {
            qs[i / 2] = q;
        } else {
            qs[i / 2] |= q << 4;
        }
    }
BlockQ4K {
d,
dmin,
scales: sub_scales,
qs,
}
}
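// For reference, an illustrative helper (not used by the kernels in this
// file) showing how the 6-bit fields packed above can be read back; the
// 0x3F mask strips bits belonging to the neighboring field.
#[allow(dead_code)]
fn q4k_sub_scale(scales: &[u8; 12], sb: usize) -> u8 {
    let bit = sb * 6;
    let byte_idx = bit / 8;
    let bit_offset = bit % 8;
    let mut v = scales[byte_idx] >> bit_offset;
    if bit_offset > 2 {
        v |= scales[byte_idx + 1] << (8 - bit_offset);
    }
    v & 0x3F
}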
/// Computes `y = A * x` where `A` is an INT8-quantized `m x n` row-major
/// matrix with per-tensor `scale`. Activations are quantized to INT8 on the
/// fly inside the kernel. NEON is baseline on aarch64, so the SIMD path
/// needs no runtime feature detection.
#[inline(always)]
pub fn int8_gemv_neon(a: &[i8], x: &[f32], y: &mut [f32], m: usize, n: usize, scale: f32) {
    debug_assert_eq!(a.len(), m * n);
    debug_assert_eq!(x.len(), n);
    debug_assert_eq!(y.len(), m);
#[cfg(target_arch = "aarch64")]
unsafe {
int8_gemv_neon_impl(a, x, y, m, n, scale);
}
#[cfg(not(target_arch = "aarch64"))]
{
int8_gemv_scalar(a, x, y, m, n, scale);
}
}
#[cfg(target_arch = "aarch64")]
#[inline(always)]
// Safety: callers must uphold the slice-length invariants asserted by the
// `int8_gemv_neon` wrapper.
unsafe fn int8_gemv_neon_impl(a: &[i8], x: &[f32], y: &mut [f32], m: usize, n: usize, scale: f32) {
    let a_ptr = a.as_ptr();
    let x_ptr = x.as_ptr();
    let y_ptr = y.as_mut_ptr();
    // Dynamically quantize the activations to symmetric INT8 so the inner
    // loops run as pure integer dot products.
    let mut x_max: f32 = 0.0;
    for i in 0..n {
        x_max = x_max.max((*x_ptr.add(i)).abs());
    }
    let x_scale = if x_max > 0.0 { x_max / 127.0 } else { 1.0 };
    let x_inv_scale = 1.0 / x_scale;
    let mut x_i8 = vec![0i8; n];
    for i in 0..n {
        x_i8[i] = ((*x_ptr.add(i) * x_inv_scale).round().clamp(-127.0, 127.0)) as i8;
    }
    let x_i8_ptr = x_i8.as_ptr();
    // A single f32 multiply at the end undoes both quantization scales.
    let combined_scale = scale * x_scale;
    // Process four rows at a time; four independent accumulators hide the
    // latency of the widening multiply-accumulate chain.
    let row_chunks = m / 4;
    for rc in 0..row_chunks {
        let row_base = rc * 4;
        let mut acc0 = vdupq_n_s32(0);
        let mut acc1 = vdupq_n_s32(0);
        let mut acc2 = vdupq_n_s32(0);
        let mut acc3 = vdupq_n_s32(0);
        let col_chunks = n / 16;
        let mut col = 0usize;
        for _ in 0..col_chunks {
            let x_v = vld1q_s8(x_i8_ptr.add(col));
            let a0 = vld1q_s8(a_ptr.add((row_base + 0) * n + col));
            let a1 = vld1q_s8(a_ptr.add((row_base + 1) * n + col));
            let a2 = vld1q_s8(a_ptr.add((row_base + 2) * n + col));
            let a3 = vld1q_s8(a_ptr.add((row_base + 3) * n + col));
            // Widening i8 x i8 -> i16 multiplies (exact, since |q| <= 127),
            // then pairwise accumulation into i32 lanes.
            let a0_lo = vget_low_s8(a0);
            let a0_hi = vget_high_s8(a0);
            let x_lo = vget_low_s8(x_v);
            let x_hi = vget_high_s8(x_v);
            let prod0_lo = vmull_s8(a0_lo, x_lo);
            let prod0_hi = vmull_s8(a0_hi, x_hi);
            acc0 = vpadalq_s16(acc0, prod0_lo);
            acc0 = vpadalq_s16(acc0, prod0_hi);
let a1_lo = vget_low_s8(a1);
let a1_hi = vget_high_s8(a1);
let prod1_lo = vmull_s8(a1_lo, x_lo);
let prod1_hi = vmull_s8(a1_hi, x_hi);
acc1 = vpadalq_s16(acc1, prod1_lo);
acc1 = vpadalq_s16(acc1, prod1_hi);
let a2_lo = vget_low_s8(a2);
let a2_hi = vget_high_s8(a2);
let prod2_lo = vmull_s8(a2_lo, x_lo);
let prod2_hi = vmull_s8(a2_hi, x_hi);
acc2 = vpadalq_s16(acc2, prod2_lo);
acc2 = vpadalq_s16(acc2, prod2_hi);
let a3_lo = vget_low_s8(a3);
let a3_hi = vget_high_s8(a3);
let prod3_lo = vmull_s8(a3_lo, x_lo);
let prod3_hi = vmull_s8(a3_hi, x_hi);
acc3 = vpadalq_s16(acc3, prod3_lo);
acc3 = vpadalq_s16(acc3, prod3_hi);
col += 16;
}
        // Horizontal reduction, then a scalar tail for the last n % 16
        // columns.
        let mut sum0 = vaddvq_s32(acc0);
        let mut sum1 = vaddvq_s32(acc1);
        let mut sum2 = vaddvq_s32(acc2);
        let mut sum3 = vaddvq_s32(acc3);
        for c in col..n {
let x_val = *x_i8_ptr.add(c) as i32;
sum0 += (*a_ptr.add((row_base + 0) * n + c) as i32) * x_val;
sum1 += (*a_ptr.add((row_base + 1) * n + c) as i32) * x_val;
sum2 += (*a_ptr.add((row_base + 2) * n + c) as i32) * x_val;
sum3 += (*a_ptr.add((row_base + 3) * n + c) as i32) * x_val;
}
*y_ptr.add(row_base + 0) = (sum0 as f32) * combined_scale;
*y_ptr.add(row_base + 1) = (sum1 as f32) * combined_scale;
*y_ptr.add(row_base + 2) = (sum2 as f32) * combined_scale;
*y_ptr.add(row_base + 3) = (sum3 as f32) * combined_scale;
}
    // Remaining m % 4 rows, one at a time.
    for row in (row_chunks * 4)..m {
let mut acc = vdupq_n_s32(0);
let col_chunks = n / 16;
let mut col = 0usize;
for _ in 0..col_chunks {
let x_v = vld1q_s8(x_i8_ptr.add(col));
let a_v = vld1q_s8(a_ptr.add(row * n + col));
let a_lo = vget_low_s8(a_v);
let a_hi = vget_high_s8(a_v);
let x_lo = vget_low_s8(x_v);
let x_hi = vget_high_s8(x_v);
let prod_lo = vmull_s8(a_lo, x_lo);
let prod_hi = vmull_s8(a_hi, x_hi);
acc = vpadalq_s16(acc, prod_lo);
acc = vpadalq_s16(acc, prod_hi);
col += 16;
}
let mut sum = vaddvq_s32(acc);
for c in col..n {
sum += (*a_ptr.add(row * n + c) as i32) * (*x_i8_ptr.add(c) as i32);
}
*y_ptr.add(row) = (sum as f32) * combined_scale;
}
}
/// Scalar reference implementation; also the fallback on non-aarch64 targets.
#[allow(dead_code)]
fn int8_gemv_scalar(a: &[i8], x: &[f32], y: &mut [f32], m: usize, n: usize, scale: f32) {
let x_max = x.iter().fold(0.0f32, |acc, &v| acc.max(v.abs()));
let x_scale = if x_max > 0.0 { x_max / 127.0 } else { 1.0 };
let x_inv_scale = 1.0 / x_scale;
let x_i8: Vec<i8> = x
.iter()
.map(|&v| (v * x_inv_scale).round().clamp(-127.0, 127.0) as i8)
.collect();
let combined_scale = scale * x_scale;
for row in 0..m {
let mut sum: i32 = 0;
for col in 0..n {
sum += (a[row * n + col] as i32) * (x_i8[col] as i32);
}
y[row] = (sum as f32) * combined_scale;
}
}
/// Computes `y = A * x` where `A` is an INT4-quantized `m x n` row-major
/// matrix; `scales`/`mins` hold one entry per block in row-major order
/// (`m * ceil(n / block_size)` entries). `block_size` must be even so blocks
/// stay byte-aligned.
#[inline(always)]
pub fn int4_gemv_neon(
    a: &[u8],
    x: &[f32],
    y: &mut [f32],
    m: usize,
    n: usize,
    scales: &[f32],
    mins: &[f32],
    block_size: usize,
) {
    debug_assert_eq!(a.len(), m * ((n + 1) / 2));
    debug_assert_eq!(x.len(), n);
    debug_assert_eq!(y.len(), m);
    debug_assert_eq!(scales.len(), m * ((n + block_size - 1) / block_size));
    debug_assert!(block_size % 2 == 0);
#[cfg(target_arch = "aarch64")]
unsafe {
int4_gemv_neon_impl(a, x, y, m, n, scales, mins, block_size);
}
#[cfg(not(target_arch = "aarch64"))]
{
int4_gemv_scalar(a, x, y, m, n, scales, mins, block_size);
}
}
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn int4_gemv_neon_impl(
a: &[u8],
x: &[f32],
y: &mut [f32],
m: usize,
n: usize,
scales: &[f32],
mins: &[f32],
block_size: usize,
) {
    let a_ptr = a.as_ptr();
    let x_ptr = x.as_ptr();
    let y_ptr = y.as_mut_ptr();
    // Packed bytes per matrix row, and scale/min blocks per row.
    let row_bytes = (n + 1) / 2;
    let blocks_per_row = (n + block_size - 1) / block_size;
    let low_mask = vdupq_n_u8(0x0F);
for row in 0..m {
let mut acc = vdupq_n_f32(0.0);
let mut scalar_acc: f32 = 0.0;
let row_start = row * row_bytes;
for block_idx in 0..blocks_per_row {
let block_start_elem = block_idx * block_size;
let block_start_byte = block_start_elem / 2;
let elements_in_block = (n - block_start_elem).min(block_size);
let scale = scales[row * blocks_per_row + block_idx];
let min = mins[row * blocks_per_row + block_idx];
let scale_vec = vdupq_n_f32(scale);
let min_vec = vdupq_n_f32(min);
let mut elem = 0usize;
            // Process 16 elements (8 packed bytes) per iteration; the 8-byte
            // load stays in bounds because the block holds at least
            // (elem + 16) / 2 bytes.
            while elem + 16 <= elements_in_block {
                let byte_offset = row_start + block_start_byte + elem / 2;
                let packed = vld1_u8(a_ptr.add(byte_offset));
                let low = vand_u8(packed, vget_low_u8(low_mask));
                let high = vshr_n_u8::<4>(packed);
                // Interleave low/high nibbles back into element order:
                // lo = elements 0..8, hi = elements 8..16 of this chunk.
                let unpacked_lo = vzip1_u8(low, high);
                let unpacked_hi = vzip2_u8(low, high);
                let q0 = vcvtq_f32_u32(vmovl_u16(vget_low_u16(vmovl_u8(unpacked_lo))));
                let q1 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vmovl_u8(unpacked_lo))));
                let q2 = vcvtq_f32_u32(vmovl_u16(vget_low_u16(vmovl_u8(unpacked_hi))));
                let q3 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vmovl_u8(unpacked_hi))));
                // Dequantize: value = q * scale + min, fused per 4 lanes.
                let deq0 = vfmaq_f32(min_vec, q0, scale_vec);
                let deq1 = vfmaq_f32(min_vec, q1, scale_vec);
                let deq2 = vfmaq_f32(min_vec, q2, scale_vec);
                let deq3 = vfmaq_f32(min_vec, q3, scale_vec);
                let x0 = vld1q_f32(x_ptr.add(block_start_elem + elem));
                let x1 = vld1q_f32(x_ptr.add(block_start_elem + elem + 4));
                let x2 = vld1q_f32(x_ptr.add(block_start_elem + elem + 8));
                let x3 = vld1q_f32(x_ptr.add(block_start_elem + elem + 12));
                acc = vfmaq_f32(acc, deq0, x0);
                acc = vfmaq_f32(acc, deq1, x1);
                acc = vfmaq_f32(acc, deq2, x2);
                acc = vfmaq_f32(acc, deq3, x3);
                elem += 16;
            }
            // Scalar tail: fewer than 16 remaining elements in this block.
            while elem < elements_in_block {
let byte_idx = row_start + block_start_byte + elem / 2;
let byte = *a_ptr.add(byte_idx);
let q = if elem % 2 == 0 {
byte & 0x0F
} else {
byte >> 4
};
let val = (q as f32) * scale + min;
scalar_acc += val * *x_ptr.add(block_start_elem + elem);
elem += 1;
}
}
        // Reduce the SIMD accumulator and fold in the scalar tail.
        *y_ptr.add(row) = vaddvq_f32(acc) + scalar_acc;
}
}
/// Scalar reference implementation; also the fallback on non-aarch64 targets.
#[allow(dead_code)]
fn int4_gemv_scalar(
a: &[u8],
x: &[f32],
y: &mut [f32],
m: usize,
n: usize,
scales: &[f32],
mins: &[f32],
block_size: usize,
) {
let row_bytes = (n + 1) / 2;
let blocks_per_row = (n + block_size - 1) / block_size;
for row in 0..m {
let mut sum: f32 = 0.0;
let row_start = row * row_bytes;
for block_idx in 0..blocks_per_row {
let block_start_elem = block_idx * block_size;
let block_start_byte = block_start_elem / 2;
let elements_in_block = (n - block_start_elem).min(block_size);
let scale = scales[row * blocks_per_row + block_idx];
let min = mins[row * blocks_per_row + block_idx];
for elem in 0..elements_in_block {
let byte_idx = row_start + block_start_byte + elem / 2;
let byte = a[byte_idx];
let q = if elem % 2 == 0 {
byte & 0x0F
} else {
byte >> 4
};
let val = (q as f32) * scale + min;
sum += val * x[block_start_elem + elem];
}
}
y[row] = sum;
}
}
/// Computes `y = A * x` where each row of `A` is a sequence of Q4_K
/// super-blocks; `n` must be a multiple of `Q4K_SUPER_BLOCK_SIZE`.
#[inline(always)]
pub fn q4k_gemv_neon(blocks: &[BlockQ4K], x: &[f32], y: &mut [f32], m: usize, n: usize) {
debug_assert_eq!(n % Q4K_SUPER_BLOCK_SIZE, 0);
let blocks_per_row = n / Q4K_SUPER_BLOCK_SIZE;
debug_assert_eq!(blocks.len(), m * blocks_per_row);
debug_assert_eq!(x.len(), n);
debug_assert_eq!(y.len(), m);
#[cfg(target_arch = "aarch64")]
unsafe {
q4k_gemv_neon_impl(blocks, x, y, m, n, blocks_per_row);
}
#[cfg(not(target_arch = "aarch64"))]
{
q4k_gemv_scalar(blocks, x, y, m, n, blocks_per_row);
}
}
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn q4k_gemv_neon_impl(
blocks: &[BlockQ4K],
x: &[f32],
y: &mut [f32],
m: usize,
_n: usize,
blocks_per_row: usize,
) {
let x_ptr = x.as_ptr();
let y_ptr = y.as_mut_ptr();
let low_mask = vdupq_n_u8(0x0F);
for row in 0..m {
let mut acc = vdupq_n_f32(0.0);
for block_idx in 0..blocks_per_row {
let block = &blocks[row * blocks_per_row + block_idx];
let x_offset = block_idx * Q4K_SUPER_BLOCK_SIZE;
            let d = f16_to_f32(block.d);
            let dmin = f16_to_f32(block.dmin);
            // Codes 0..=15 span [dmin, dmin + d]: value = q * (d / 15) + dmin.
            let scale_vec = vdupq_n_f32(d / 15.0);
            let min_vec = vdupq_n_f32(dmin);
            // Process 16 elements (8 packed bytes) per step, loading directly
            // from `qs`; 256 codes occupy 128 bytes, so every 8-byte load is
            // in bounds.
            for i in (0..Q4K_SUPER_BLOCK_SIZE).step_by(16) {
                let byte_idx = i / 2;
                let packed = vld1_u8(block.qs.as_ptr().add(byte_idx));
                let low = vand_u8(packed, vget_low_u8(low_mask));
                let high = vshr_n_u8::<4>(packed);
                // Interleave low/high nibbles back into element order.
                let unpacked_lo = vzip1_u8(low, high);
                let unpacked_hi = vzip2_u8(low, high);
                let q0 = vcvtq_f32_u32(vmovl_u16(vget_low_u16(vmovl_u8(unpacked_lo))));
                let q1 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vmovl_u8(unpacked_lo))));
                let q2 = vcvtq_f32_u32(vmovl_u16(vget_low_u16(vmovl_u8(unpacked_hi))));
                let q3 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vmovl_u8(unpacked_hi))));
                let deq0 = vfmaq_f32(min_vec, q0, scale_vec);
                let deq1 = vfmaq_f32(min_vec, q1, scale_vec);
                let deq2 = vfmaq_f32(min_vec, q2, scale_vec);
                let deq3 = vfmaq_f32(min_vec, q3, scale_vec);
                acc = vfmaq_f32(acc, deq0, vld1q_f32(x_ptr.add(x_offset + i)));
                acc = vfmaq_f32(acc, deq1, vld1q_f32(x_ptr.add(x_offset + i + 4)));
                acc = vfmaq_f32(acc, deq2, vld1q_f32(x_ptr.add(x_offset + i + 8)));
                acc = vfmaq_f32(acc, deq3, vld1q_f32(x_ptr.add(x_offset + i + 12)));
            }
}
*y_ptr.add(row) = vaddvq_f32(acc);
}
}
/// Scalar reference implementation; also the fallback on non-aarch64 targets.
#[allow(dead_code)]
fn q4k_gemv_scalar(
blocks: &[BlockQ4K],
x: &[f32],
y: &mut [f32],
m: usize,
_n: usize,
blocks_per_row: usize,
) {
for row in 0..m {
let mut sum: f32 = 0.0;
for block_idx in 0..blocks_per_row {
let block = &blocks[row * blocks_per_row + block_idx];
let x_offset = block_idx * Q4K_SUPER_BLOCK_SIZE;
let d = f16_to_f32(block.d);
let dmin = f16_to_f32(block.dmin);
let scale = d / 15.0;
for i in 0..Q4K_SUPER_BLOCK_SIZE {
let byte_idx = i / 2;
let byte = block.qs[byte_idx];
let q = if i % 2 == 0 { byte & 0x0F } else { byte >> 4 };
let val = (q as f32) * scale + dmin;
sum += val * x[x_offset + i];
}
}
y[row] = sum;
}
}
/// Scalar f32 -> f16 (IEEE 754 binary16) conversion, truncating the mantissa.
#[inline(always)]
fn f32_to_f16(x: f32) -> u16 {
    let bits = x.to_bits();
    let sign = (bits >> 16) & 0x8000;
    let exp = ((bits >> 23) & 0xFF) as i32;
    let frac = bits & 0x007F_FFFF;
    if exp == 0xFF {
        // Inf/NaN; keep NaNs quiet by setting a mantissa bit.
        return (sign | 0x7C00 | ((frac != 0) as u32 * 0x0200)) as u16;
    }
    let new_exp = exp - 127 + 15;
    if new_exp >= 31 {
        // Overflow to infinity.
        return (sign | 0x7C00) as u16;
    }
    if new_exp <= 0 {
        // Subnormal half, or underflow to signed zero.
        if new_exp < -10 {
            return sign as u16;
        }
        // Shift the significand (with its implicit leading bit restored)
        // straight into the 10-bit subnormal field.
        let frac = (frac | 0x0080_0000) >> (14 - new_exp);
        return (sign | frac) as u16;
    }
    (sign | ((new_exp as u32) << 10) | (frac >> 13)) as u16
}
/// Scalar f16 (IEEE 754 binary16) -> f32 conversion.
#[inline(always)]
fn f16_to_f32(x: u16) -> f32 {
    let sign = ((x & 0x8000) as u32) << 16;
    let exp = ((x >> 10) & 0x1F) as u32;
    let frac = (x & 0x03FF) as u32;
    if exp == 0 {
        if frac == 0 {
            return f32::from_bits(sign);
        }
        // Subnormal half: renormalize. After `e` left shifts the implicit
        // bit is in place, so value = 1.f * 2^(-14 - e) and the biased f32
        // exponent is 127 - 14 - e = 113 - e.
        let mut e = 0u32;
        let mut f = frac;
        while (f & 0x0400) == 0 {
            f <<= 1;
            e += 1;
        }
        f &= 0x03FF;
        return f32::from_bits(sign | ((113 - e) << 23) | (f << 13));
    }
    if exp == 31 {
        // Inf/NaN.
        return f32::from_bits(sign | 0x7F80_0000 | (frac << 13));
    }
    f32::from_bits(sign | ((exp + 127 - 15) << 23) | (frac << 13))
}
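// Spot checks: 1.0f32 (biased exponent 127) encodes as (127 - 112) << 10 =
// 0x3C00, and -2.0 as 0xC000; see `test_f16_conversion` below.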
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_int8_quantization_roundtrip() {
let data = vec![0.5, -0.3, 1.0, -0.8, 0.0, 0.25, -0.125, 0.75];
let (quantized, scale) = quantize_to_int8(&data);
let dequantized = dequantize_int8(&quantized, scale);
for (orig, deq) in data.iter().zip(dequantized.iter()) {
let error = (orig - deq).abs() / orig.abs().max(0.01);
assert!(
error < 0.02,
"INT8 quantization error too high: {} vs {}",
orig,
deq
);
}
}
#[test]
fn test_int4_quantization_roundtrip() {
let data: Vec<f32> = (0..64).map(|i| (i as f32 - 32.0) / 32.0).collect();
let (packed, scales, mins) = quantize_to_int4(&data, INT4_BLOCK_SIZE);
let dequantized = dequantize_int4(&packed, &scales, &mins, INT4_BLOCK_SIZE, data.len());
for (orig, deq) in data.iter().zip(dequantized.iter()) {
let error = (orig - deq).abs();
assert!(
error < 0.1,
"INT4 quantization error too high: {} vs {}",
orig,
deq
);
}
}
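    // Pins down the nibble layout assumed by every kernel in this file:
    // element 2k in the low nibble, element 2k + 1 in the high nibble.
    #[test]
    fn test_int4_nibble_order() {
        // min = 0, scale = (15 - 0) / 15 = 1, so values quantize to themselves.
        let data = vec![1.0f32, 2.0, 15.0, 0.0];
        let (packed, scales, mins) = quantize_to_int4(&data, 4);
        assert_eq!(scales[0], 1.0);
        assert_eq!(mins[0], 0.0);
        assert_eq!(packed[0], 0x21); // q0 = 1 (low), q1 = 2 (high)
        assert_eq!(packed[1], 0x0F); // q2 = 15 (low), q3 = 0 (high)
    }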
#[test]
fn test_int8_gemv_accuracy() {
let m = 32;
let n = 64;
let a_f32: Vec<f32> = (0..m * n).map(|i| ((i % 7) as f32 - 3.0) / 10.0).collect();
let x: Vec<f32> = (0..n).map(|i| ((i % 5) as f32 - 2.0) / 5.0).collect();
let (a_i8, scale) = quantize_to_int8(&a_f32);
let mut y_quant = vec![0.0f32; m];
int8_gemv_neon(&a_i8, &x, &mut y_quant, m, n, scale);
let mut y_ref = vec![0.0f32; m];
for row in 0..m {
for col in 0..n {
y_ref[row] += a_f32[row * n + col] * x[col];
}
}
for i in 0..m {
let rel_error = (y_quant[i] - y_ref[i]).abs() / y_ref[i].abs().max(0.01);
let abs_error = (y_quant[i] - y_ref[i]).abs();
assert!(
rel_error < 0.03 || abs_error < 0.01,
"INT8 GEMV error at row {}: {} vs {} (rel: {:.4}, abs: {:.6})",
i,
y_quant[i],
y_ref[i],
rel_error,
abs_error
);
}
}
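    // The dispatching kernel and the scalar reference quantize activations
    // identically and accumulate in exact integer arithmetic, so they should
    // agree to within f32 rounding of the final rescale. On non-NEON targets
    // both paths are scalar, making this a smoke test there.
    #[test]
    fn test_int8_gemv_matches_scalar() {
        let m = 5; // not a multiple of 4: exercises the row tail
        let n = 19; // not a multiple of 16: exercises the column tail
        let a_f32: Vec<f32> = (0..m * n).map(|i| ((i % 13) as f32 - 6.0) / 8.0).collect();
        let x: Vec<f32> = (0..n).map(|i| ((i % 9) as f32 - 4.0) / 4.0).collect();
        let (a_i8, scale) = quantize_to_int8(&a_f32);
        let mut y_main = vec![0.0f32; m];
        let mut y_ref = vec![0.0f32; m];
        int8_gemv_neon(&a_i8, &x, &mut y_main, m, n, scale);
        int8_gemv_scalar(&a_i8, &x, &mut y_ref, m, n, scale);
        for i in 0..m {
            assert!(
                (y_main[i] - y_ref[i]).abs() < 1e-5,
                "mismatch at row {}: {} vs {}",
                i,
                y_main[i],
                y_ref[i]
            );
        }
    }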
#[test]
fn test_int4_gemv_accuracy() {
let m = 16;
let n = 64;
let block_size = INT4_BLOCK_SIZE;
let a_f32: Vec<f32> = (0..m * n).map(|i| ((i % 11) as f32 - 5.0) / 10.0).collect();
let x: Vec<f32> = (0..n).map(|i| ((i % 7) as f32 - 3.0) / 5.0).collect();
let blocks_per_row = (n + block_size - 1) / block_size;
let mut all_packed = Vec::new();
let mut all_scales = Vec::new();
let mut all_mins = Vec::new();
for row in 0..m {
let row_data = &a_f32[row * n..(row + 1) * n];
let (packed, scales, mins) = quantize_to_int4(row_data, block_size);
all_packed.extend(packed);
all_scales.extend(scales);
all_mins.extend(mins);
}
let mut y_quant = vec![0.0f32; m];
int4_gemv_neon(
&all_packed,
&x,
&mut y_quant,
m,
n,
&all_scales,
&all_mins,
block_size,
);
let mut y_ref = vec![0.0f32; m];
for row in 0..m {
for col in 0..n {
y_ref[row] += a_f32[row * n + col] * x[col];
}
}
for i in 0..m {
let rel_error = (y_quant[i] - y_ref[i]).abs() / y_ref[i].abs().max(0.01);
let abs_error = (y_quant[i] - y_ref[i]).abs();
assert!(
rel_error < 0.10 || abs_error < 0.1,
"INT4 GEMV error at row {}: {} vs {} (rel: {:.4}, abs: {:.6})",
i,
y_quant[i],
y_ref[i],
rel_error,
abs_error
);
}
}
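    // Cross-check the INT4 kernel against its scalar reference. Accumulation
    // order and FMA contraction differ between the two paths, hence the
    // small tolerance.
    #[test]
    fn test_int4_gemv_matches_scalar() {
        let m = 3;
        let n = 40; // one full 32-wide block plus an 8-element tail block
        let block_size = INT4_BLOCK_SIZE;
        let a_f32: Vec<f32> = (0..m * n).map(|i| ((i % 17) as f32 - 8.0) / 8.0).collect();
        let x: Vec<f32> = (0..n).map(|i| ((i % 5) as f32 - 2.0) / 2.0).collect();
        let mut packed = Vec::new();
        let mut scales = Vec::new();
        let mut mins = Vec::new();
        for row in 0..m {
            let (p, s, mn) = quantize_to_int4(&a_f32[row * n..(row + 1) * n], block_size);
            packed.extend(p);
            scales.extend(s);
            mins.extend(mn);
        }
        let mut y_main = vec![0.0f32; m];
        let mut y_ref = vec![0.0f32; m];
        int4_gemv_neon(&packed, &x, &mut y_main, m, n, &scales, &mins, block_size);
        int4_gemv_scalar(&packed, &x, &mut y_ref, m, n, &scales, &mins, block_size);
        for i in 0..m {
            let tol = 1e-3 + 1e-3 * y_ref[i].abs();
            assert!(
                (y_main[i] - y_ref[i]).abs() < tol,
                "mismatch at row {}: {} vs {}",
                i,
                y_main[i],
                y_ref[i]
            );
        }
    }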
#[test]
fn test_q4k_structure() {
assert_eq!(std::mem::size_of::<BlockQ4K>(), 2 + 2 + 12 + 128);
}
#[test]
fn test_f16_conversion() {
let values = [0.0f32, 1.0, -1.0, 0.5, 65504.0, 0.00006103515625];
for &v in &values {
let h = f32_to_f16(v);
let back = f16_to_f32(h);
let error = (v - back).abs() / v.abs().max(1e-6);
assert!(
error < 0.01 || (v - back).abs() < 1e-6,
"F16 roundtrip error: {} -> {} -> {}",
v,
h,
back
);
}
}
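    // Power-of-two subnormals convert exactly (truncation drops no mantissa
    // bits), so the roundtrip through f16 must be lossless for them.
    #[test]
    fn test_f16_subnormal_roundtrip() {
        for &v in &[2.0f32.powi(-15), 2.0f32.powi(-20), 2.0f32.powi(-24)] {
            let back = f16_to_f32(f32_to_f16(v));
            assert_eq!(back, v, "subnormal roundtrip failed for {}", v);
            let back_neg = f16_to_f32(f32_to_f16(-v));
            assert_eq!(back_neg, -v, "negative subnormal roundtrip failed for {}", -v);
        }
    }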
#[test]
fn test_q4k_quantization() {
let data: Vec<f32> = (0..Q4K_SUPER_BLOCK_SIZE)
.map(|i| ((i as f32) - 128.0) / 128.0)
.collect();
let block = quantize_to_q4k(&data);
assert!(f16_to_f32(block.d) > 0.0);
let scale = f16_to_f32(block.d) / 15.0;
let min = f16_to_f32(block.dmin);
for i in 0..8 {
let byte_idx = i / 2;
let q = if i % 2 == 0 {
block.qs[byte_idx] & 0x0F
} else {
block.qs[byte_idx] >> 4
};
let deq = (q as f32) * scale + min;
let orig = data[i];
let error = (deq - orig).abs();
assert!(error < 0.2, "Q4_K error at {}: {} vs {}", i, deq, orig);
}
}
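    // Cross-check the Q4_K NEON kernel against its scalar reference; both
    // dequantize with the same global scale, so differences are limited to
    // f32 accumulation order and FMA contraction.
    #[test]
    fn test_q4k_gemv_matches_scalar() {
        let m = 2;
        let n = Q4K_SUPER_BLOCK_SIZE * 2;
        let a_f32: Vec<f32> = (0..m * n).map(|i| ((i % 23) as f32 - 11.0) / 16.0).collect();
        let x: Vec<f32> = (0..n).map(|i| ((i % 7) as f32 - 3.0) / 4.0).collect();
        let mut blocks = Vec::new();
        for row in 0..m {
            for b in 0..n / Q4K_SUPER_BLOCK_SIZE {
                let start = row * n + b * Q4K_SUPER_BLOCK_SIZE;
                blocks.push(quantize_to_q4k(&a_f32[start..start + Q4K_SUPER_BLOCK_SIZE]));
            }
        }
        let mut y_main = vec![0.0f32; m];
        let mut y_ref = vec![0.0f32; m];
        q4k_gemv_neon(&blocks, &x, &mut y_main, m, n);
        q4k_gemv_scalar(&blocks, &x, &mut y_ref, m, n, n / Q4K_SUPER_BLOCK_SIZE);
        for i in 0..m {
            let tol = 1e-2 + 1e-3 * y_ref[i].abs();
            assert!(
                (y_main[i] - y_ref[i]).abs() < tol,
                "mismatch at row {}: {} vs {}",
                i,
                y_main[i],
                y_ref[i]
            );
        }
    }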
#[test]
fn test_int8_gemv_large() {
let m = 128;
let n = 512;
let a_f32: Vec<f32> = (0..m * n)
.map(|i| ((i % 127) as f32 - 63.0) / 100.0)
.collect();
let x: Vec<f32> = (0..n).map(|i| ((i % 63) as f32 - 31.0) / 50.0).collect();
let (a_i8, scale) = quantize_to_int8(&a_f32);
let mut y_quant = vec![0.0f32; m];
int8_gemv_neon(&a_i8, &x, &mut y_quant, m, n, scale);
let a_deq = dequantize_int8(&a_i8, scale);
let mut y_ref = vec![0.0f32; m];
for row in 0..m {
for col in 0..n {
y_ref[row] += a_deq[row * n + col] * x[col];
}
}
assert!(y_quant.iter().all(|&v| v.is_finite()));
assert!(y_ref.iter().all(|&v| v.is_finite()));
for &i in &[0, m / 2, m - 1] {
let abs_error = (y_quant[i] - y_ref[i]).abs();
let tolerance = y_ref[i].abs() * 0.15 + 0.1;
assert!(
abs_error < tolerance,
"Large INT8 GEMV error at row {}: {} vs {} (abs: {:.6}, tol: {:.6})",
i,
y_quant[i],
y_ref[i],
abs_error,
tolerance
);
}
}
#[test]
fn test_int4_block_boundary() {
let block_size = INT4_BLOCK_SIZE;
let n = block_size * 2 + 7;
let data: Vec<f32> = (0..n).map(|i| (i as f32) / (n as f32)).collect();
let (packed, scales, mins) = quantize_to_int4(&data, block_size);
let dequantized = dequantize_int4(&packed, &scales, &mins, block_size, data.len());
assert_eq!(dequantized.len(), n);
for &i in &[0, block_size - 1, block_size, block_size * 2 - 1, n - 1] {
let error = (data[i] - dequantized[i]).abs();
assert!(
error < 0.15,
"INT4 boundary error at {}: {} vs {}",
i,
data[i],
dequantized[i]
);
}
}
}