use rayon::prelude::*;
use super::{dequant, dequant_k_quants};
use crate::quant::QuantFormat;
/// Dot product of two equal-length `f32` slices, using the fastest
/// implementation available on the current target.
fn dot_f32(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
            // SAFETY: AVX2 and FMA were just verified at runtime, and both
            // pointers are valid for `a.len()` elements.
            return unsafe {
                super::simd::dot_f32::dot_f32_avx2_fma(a.as_ptr(), b.as_ptr(), a.len())
            };
        }
        // Scalar fallback for x86_64 CPUs without AVX2/FMA.
        let mut acc = 0.0f32;
        for (x, y) in a.iter().zip(b.iter()) {
            acc += x * y;
        }
        acc
    }
    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: NEON is baseline on aarch64; pointers cover `a.len()` elements.
        unsafe { super::simd::aarch64::dot_f32::dot_f32_neon(a.as_ptr(), b.as_ptr(), a.len()) }
    }
    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    {
        // Portable scalar path for all other architectures.
        a.iter().zip(b.iter()).map(|(&x, &y)| x * y).sum()
    }
}
/// Computes `output[m x n] = act[m x k] * weight^T`, where the weight matrix is
/// stored row-major as `n` rows of quantized blocks in `format`.
///
/// Three execution strategies, fastest first:
/// 1. K-quant format with `k % 256 == 0`: activations are pre-quantized to
///    Q8_K once, then a fused integer dot kernel runs per output element.
/// 2. Other K-quant cases: fused f32-activation dot kernels, processing two
///    output columns per iteration.
/// 3. Everything else: dequantize each weight row to f32 once per chunk, then
///    take a plain f32 dot against every activation row.
///
/// Work is parallelized over disjoint output-column ranges with rayon; each
/// range is written by exactly one task, so the raw-pointer writes below
/// never alias.
pub fn quant_matmul_f32(
    act: &[f32],
    weight_bytes: &[u8],
    output: &mut [f32],
    m: usize,
    k: usize,
    n: usize,
    format: QuantFormat,
) {
    debug_assert_eq!(act.len(), m * k);
    debug_assert_eq!(output.len(), m * n);
    let block_size = format.block_size();
    let block_bytes = format.block_bytes();
    let blocks_per_row = k / block_size;
    let row_bytes = blocks_per_row * block_bytes;
    debug_assert_eq!(weight_bytes.len(), n * row_bytes);
    // Oversubscribe with ~4 chunks per thread for load balancing, except in
    // the single-activation-row (decode) case where one chunk per thread
    // keeps per-task overhead minimal.
    let num_threads = rayon::current_num_threads();
    let target_chunks = if m == 1 { num_threads } else { num_threads * 4 };
    let chunk_size = n.div_ceil(target_chunks);
    // Floor of 16 columns per chunk so tiny chunks don't drown in overhead.
    let chunk_size = chunk_size.max(16);
    // Half-open column ranges [start, end) covering 0..n without overlap.
    let col_ranges: Vec<(usize, usize)> = (0..n)
        .step_by(chunk_size)
        .map(|start| (start, (start + chunk_size).min(n)))
        .collect();
    // Smuggle the output pointer across threads as usize (raw pointers are
    // not Send). Sound because the column ranges are disjoint, so no two
    // tasks ever write the same element.
    let output_ptr = output.as_mut_ptr() as usize;
    // Formats with fused activation-dot kernels (the K-quants).
    let use_fused = matches!(
        format,
        QuantFormat::Q2K
            | QuantFormat::Q3K
            | QuantFormat::Q4K
            | QuantFormat::Q5K
            | QuantFormat::Q6K
    );
    // Q8_K activation quantization operates on 256-element super-blocks, so
    // it is only usable when k is a multiple of 256.
    let use_q8k = use_fused && k % 256 == 0;
    let q8k_block_bytes = super::simd::quantize_act_q8k::Q8K_BLOCK_BYTES;
    let q8k_blocks_per_row = k / 256;
    let q8k_row_size = q8k_blocks_per_row * q8k_block_bytes;
    // Pre-quantize every activation row to Q8_K once, up front; the buffer is
    // shared read-only by all tasks below.
    let act_q8k: Vec<u8> = if use_q8k {
        let mut buf = vec![0u8; m * q8k_row_size];
        for i in 0..m {
            let act_row = &act[i * k..(i + 1) * k];
            let q8k_row = &mut buf[i * q8k_row_size..(i + 1) * q8k_row_size];
            super::simd::quantize_act_q8k::quantize_f32_to_q8k(act_row, q8k_row);
        }
        buf
    } else {
        Vec::new()
    };
    // Same usize-smuggling trick for the (read-only) Q8_K activation buffer.
    let act_q8k_ptr = act_q8k.as_ptr() as usize;
    col_ranges.par_iter().for_each(|&(j_start, j_end)| {
        let out = output_ptr as *mut f32;
        if use_q8k {
            // Strategy 1: fused Q8_K-activation integer dot kernels.
            for i in 0..m {
                // SAFETY: act_q8k outlives the parallel loop and holds
                // m * q8k_row_size bytes; i < m keeps this slice in bounds.
                let q8k_row = unsafe {
                    std::slice::from_raw_parts(
                        (act_q8k_ptr as *const u8).add(i * q8k_row_size),
                        q8k_row_size,
                    )
                };
                for j in j_start..j_end {
                    let row_data = &weight_bytes[j * row_bytes..(j + 1) * row_bytes];
                    let val = fused_dot_q8k_dispatch(q8k_row, row_data, k, format);
                    // SAFETY: i < m and j < n, and this task exclusively owns
                    // columns [j_start, j_end), so the write cannot race.
                    unsafe {
                        *out.add(i * n + j) = val;
                    }
                }
            }
        } else if use_fused {
            // Strategy 2: fused f32-activation kernels, two columns at a
            // time to improve instruction-level parallelism.
            let cols = j_end - j_start;
            let pairs = cols / 2;
            let remainder = cols % 2;
            for i in 0..m {
                let act_row = &act[i * k..(i + 1) * k];
                for p in 0..pairs {
                    let j0 = j_start + p * 2;
                    let j1 = j0 + 1;
                    let row_data0 = &weight_bytes[j0 * row_bytes..(j0 + 1) * row_bytes];
                    let row_data1 = &weight_bytes[j1 * row_bytes..(j1 + 1) * row_bytes];
                    let val0 = fused_dot_dispatch(act_row, row_data0, k, format);
                    let val1 = fused_dot_dispatch(act_row, row_data1, k, format);
                    // SAFETY: same disjoint-column argument as above.
                    unsafe {
                        *out.add(i * n + j0) = val0;
                        *out.add(i * n + j1) = val1;
                    }
                }
                // Odd column count: handle the trailing column alone.
                if remainder > 0 {
                    let j = j_end - 1;
                    let row_data = &weight_bytes[j * row_bytes..(j + 1) * row_bytes];
                    let val = fused_dot_dispatch(act_row, row_data, k, format);
                    // SAFETY: same disjoint-column argument as above.
                    unsafe {
                        *out.add(i * n + j) = val;
                    }
                }
            }
        } else {
            // Strategy 3: dequantize each weight row once per chunk, then
            // reuse it against every activation row.
            let mut dequant_row = vec![0.0f32; k];
            for j in j_start..j_end {
                let row_start = j * row_bytes;
                let row_data = &weight_bytes[row_start..row_start + row_bytes];
                dequant_row_f32(row_data, &mut dequant_row, format);
                for i in 0..m {
                    let act_row = &act[i * k..(i + 1) * k];
                    let val = dot_f32(act_row, &dequant_row);
                    // SAFETY: same disjoint-column argument as above.
                    unsafe {
                        *out.add(i * n + j) = val;
                    }
                }
            }
        }
    });
}
/// Routes an f32-activation fused dot product to the kernel for `format`.
///
/// Only K-quant formats have fused kernels; callers must not pass any other
/// variant (the catch-all arm panics).
fn fused_dot_dispatch(act_row: &[f32], row_data: &[u8], k: usize, format: QuantFormat) -> f32 {
    let kernel: fn(&[f32], &[u8], usize) -> f32 = match format {
        QuantFormat::Q2K => super::simd::fused_q2k_dot::fused_dot_q2k,
        QuantFormat::Q3K => super::simd::fused_q3k_dot::fused_dot_q3k,
        QuantFormat::Q4K => super::simd::fused_q4k_dot::fused_dot_q4k,
        QuantFormat::Q5K => super::simd::fused_q5k_dot::fused_dot_q5k,
        QuantFormat::Q6K => super::simd::fused_q6k_dot::fused_dot_q6k,
        _ => unreachable!(),
    };
    kernel(act_row, row_data, k)
}
/// Routes a Q8_K-activation fused dot product to the kernel for `format`.
///
/// Only K-quant formats have fused Q8_K kernels; callers must not pass any
/// other variant (the catch-all arm panics).
fn fused_dot_q8k_dispatch(act_q8k: &[u8], row_data: &[u8], k: usize, format: QuantFormat) -> f32 {
    let kernel: fn(&[u8], &[u8], usize) -> f32 = match format {
        QuantFormat::Q2K => super::simd::fused_q2k_q8k_dot::fused_dot_q2k_q8k,
        QuantFormat::Q3K => super::simd::fused_q3k_q8k_dot::fused_dot_q3k_q8k,
        QuantFormat::Q4K => super::simd::fused_q4k_q8k_dot::fused_dot_q4k_q8k,
        QuantFormat::Q5K => super::simd::fused_q5k_q8k_dot::fused_dot_q5k_q8k,
        QuantFormat::Q6K => super::simd::fused_q6k_q8k_dot::fused_dot_q6k_q8k,
        _ => unreachable!(),
    };
    kernel(act_q8k, row_data, k)
}
/// Batched variant of [`quant_matmul_f32`]: multiplies one activation matrix
/// `act[m x k]` against several quantized weight matrices sharing the same
/// `k` and `format`, writing each product into the matching entry of
/// `outputs`. `weight_list[i]` is `(packed_bytes, n_i)` and `outputs[i]` must
/// hold `m * n_i` elements.
///
/// Parallelized over disjoint output-column ranges with rayon; the same
/// range is processed for every weight, so each output element is written by
/// exactly one task.
pub fn quant_matmul_batch_f32(
    act: &[f32],
    weight_list: &[(&[u8], usize)],
    outputs: &mut [&mut [f32]],
    m: usize,
    k: usize,
    format: QuantFormat,
) {
    debug_assert_eq!(act.len(), m * k);
    debug_assert_eq!(outputs.len(), weight_list.len());
    let block_size = format.block_size();
    let block_bytes = format.block_bytes();
    let blocks_per_row = k / block_size;
    let row_bytes = blocks_per_row * block_bytes;
    // Formats with fused activation-dot kernels (the K-quants).
    let use_fused = matches!(
        format,
        QuantFormat::Q2K
            | QuantFormat::Q3K
            | QuantFormat::Q4K
            | QuantFormat::Q5K
            | QuantFormat::Q6K
    );
    // Q8_K activation quantization needs k to be a multiple of its
    // 256-element super-block.
    let use_q8k = use_fused && k % 256 == 0;
    let q8k_block_bytes = super::simd::quantize_act_q8k::Q8K_BLOCK_BYTES;
    let q8k_blocks_per_row = k / 256;
    let q8k_row_size = q8k_blocks_per_row * q8k_block_bytes;
    // Pre-quantize every activation row to Q8_K once, shared read-only by
    // all weights and all tasks.
    let act_q8k: Vec<u8> = if use_q8k {
        let mut buf = vec![0u8; m * q8k_row_size];
        for i in 0..m {
            let act_row = &act[i * k..(i + 1) * k];
            let q8k_row = &mut buf[i * q8k_row_size..(i + 1) * q8k_row_size];
            super::simd::quantize_act_q8k::quantize_f32_to_q8k(act_row, q8k_row);
        }
        buf
    } else {
        Vec::new()
    };
    let act_q8k_ptr = act_q8k.as_ptr() as usize;
    let max_n: usize = weight_list.iter().map(|&(_, n)| n).max().unwrap_or(0);
    if max_n == 0 {
        return;
    }
    // Chunking policy mirrors quant_matmul_f32: ~4 chunks per thread for
    // load balancing (1x for single-row decode), at least 16 columns each.
    let num_threads = rayon::current_num_threads();
    let target_chunks = if m == 1 { num_threads } else { num_threads * 4 };
    let chunk_size = max_n.div_ceil(target_chunks).max(16);
    // Smuggle pointers across threads as usize (raw pointers are not Send).
    // FIX: derive the output pointers through `iter_mut()`/`as_mut_ptr()` so
    // they carry mutable provenance. The previous `iter()`/`as_ptr()` yielded
    // *const pointers (from a shared borrow) that were cast to *mut and
    // written through — undefined behavior under Stacked Borrows (Miri flags
    // it). Soundness of the writes: column chunks are disjoint and each
    // output element is written by exactly one task.
    let output_ptrs: Vec<(usize, usize)> = outputs
        .iter_mut()
        .zip(weight_list.iter())
        .map(|(out, &(_, n))| {
            debug_assert_eq!(out.len(), m * n);
            (out.as_mut_ptr() as usize, n)
        })
        .collect();
    let weight_ptrs: Vec<(usize, usize)> = weight_list
        .iter()
        .map(|&(w, n)| {
            debug_assert_eq!(w.len(), n * row_bytes);
            (w.as_ptr() as usize, n)
        })
        .collect();
    let col_ranges: Vec<(usize, usize)> = (0..max_n)
        .step_by(chunk_size)
        .map(|start| (start, (start + chunk_size).min(max_n)))
        .collect();
    col_ranges.par_iter().for_each(|&(j_start, j_end)| {
        // Scratch buffer for the dequantize-then-dot fallback, allocated once
        // per task (previously it was reallocated per activation row and
        // weight matrix).
        let mut dequant_row = if use_fused { Vec::new() } else { vec![0.0f32; k] };
        for (w_idx, &(w_ptr, n)) in weight_ptrs.iter().enumerate() {
            // Weights narrower than max_n contribute nothing to this chunk.
            if j_start >= n {
                continue;
            }
            let j_end_clamped = j_end.min(n);
            let (out_ptr, _) = output_ptrs[w_idx];
            let out = out_ptr as *mut f32;
            let w_base = w_ptr as *const u8;
            if use_q8k {
                // Fused Q8_K-activation integer dot kernels.
                for i in 0..m {
                    // SAFETY: act_q8k outlives the parallel loop and holds
                    // m * q8k_row_size bytes; i < m keeps this in bounds.
                    let q8k_row = unsafe {
                        std::slice::from_raw_parts(
                            (act_q8k_ptr as *const u8).add(i * q8k_row_size),
                            q8k_row_size,
                        )
                    };
                    for j in j_start..j_end_clamped {
                        // SAFETY: j < n, so row j lies within the n * row_bytes
                        // weight buffer (checked by the debug_assert above).
                        let row_data = unsafe {
                            std::slice::from_raw_parts(w_base.add(j * row_bytes), row_bytes)
                        };
                        let val = fused_dot_q8k_dispatch(q8k_row, row_data, k, format);
                        // SAFETY: i < m, j < n, and this task exclusively owns
                        // columns [j_start, j_end) of every output.
                        unsafe {
                            *out.add(i * n + j) = val;
                        }
                    }
                }
            } else if use_fused {
                // Fused f32-activation dot kernels.
                for i in 0..m {
                    let act_row = &act[i * k..(i + 1) * k];
                    for j in j_start..j_end_clamped {
                        // SAFETY: j < n keeps the row slice within the buffer.
                        let row_data = unsafe {
                            std::slice::from_raw_parts(w_base.add(j * row_bytes), row_bytes)
                        };
                        let val = fused_dot_dispatch(act_row, row_data, k, format);
                        // SAFETY: same disjoint-column argument as above.
                        unsafe {
                            *out.add(i * n + j) = val;
                        }
                    }
                }
            } else {
                // Dequantize-then-dot fallback. Column loop is outermost so
                // each weight row is dequantized once per chunk and reused
                // across all m activation rows (previously it was dequantized
                // m times).
                for j in j_start..j_end_clamped {
                    // SAFETY: j < n keeps the row slice within the buffer.
                    let row_data = unsafe {
                        std::slice::from_raw_parts(w_base.add(j * row_bytes), row_bytes)
                    };
                    dequant_row_f32(row_data, &mut dequant_row, format);
                    for i in 0..m {
                        let act_row = &act[i * k..(i + 1) * k];
                        let val = dot_f32(act_row, &dequant_row);
                        // SAFETY: same disjoint-column argument as above.
                        unsafe {
                            *out.add(i * n + j) = val;
                        }
                    }
                }
            }
        }
    });
}
pub fn dequant_row_f32(row_bytes: &[u8], output: &mut [f32], format: QuantFormat) {
match format {
QuantFormat::Q4_0 => dequant::dequant_q4_0(row_bytes, output),
QuantFormat::Q4_1 => dequant::dequant_q4_1(row_bytes, output),
QuantFormat::Q5_0 => dequant::dequant_q5_0(row_bytes, output),
QuantFormat::Q5_1 => dequant::dequant_q5_1(row_bytes, output),
QuantFormat::Q8_0 => dequant::dequant_q8_0(row_bytes, output),
QuantFormat::Q8_1 => dequant::dequant_q8_1(row_bytes, output),
QuantFormat::Q2K => dequant_k_quants::dequant_q2k(row_bytes, output),
QuantFormat::Q3K => dequant_k_quants::dequant_q3k(row_bytes, output),
QuantFormat::Q4K => dequant::dequant_q4k(row_bytes, output),
QuantFormat::Q5K => dequant_k_quants::dequant_q5k(row_bytes, output),
QuantFormat::Q6K => dequant::dequant_q6k(row_bytes, output),
QuantFormat::Q8K => dequant_k_quants::dequant_q8k(row_bytes, output),
QuantFormat::IQ4NL => dequant::dequant_iq4_nl(row_bytes, output),
QuantFormat::IQ4XS => dequant::dequant_iq4_xs(row_bytes, output),
QuantFormat::IQ2XXS => dequant::dequant_iq2_xxs(row_bytes, output),
QuantFormat::IQ2XS => dequant::dequant_iq2_xs(row_bytes, output),
QuantFormat::IQ2S => dequant::dequant_iq2_s(row_bytes, output),
QuantFormat::IQ3XXS => dequant::dequant_iq3_xxs(row_bytes, output),
QuantFormat::IQ3S => dequant::dequant_iq3_s(row_bytes, output),
QuantFormat::IQ1S => dequant::dequant_iq1_s(row_bytes, output),
QuantFormat::IQ1M => dequant::dequant_iq1_m(row_bytes, output),
QuantFormat::TQ1_0 => dequant::dequant_tq1_0(row_bytes, output),
QuantFormat::TQ2_0 => dequant::dequant_tq2_0(row_bytes, output),
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use half::f16;

    /// Builds a Q4_0 block: a 2-byte little-endian f16 scale followed by
    /// 16 bytes of packed nibbles, every nibble set to 9. Q4_0 dequantizes
    /// each element as (nibble - 8) * scale, so every weight equals `scale`.
    fn q4_0_block_of_nines(scale: f32) -> [u8; 18] {
        let mut packed = [0u8; 18];
        packed[0..2].copy_from_slice(&f16::from_f32(scale).to_le_bytes());
        packed[2..18].fill(0x99);
        packed
    }

    #[test]
    fn test_quant_matmul_q4_0_identity_like() {
        let (m, k, n) = (1, 32, 1);
        let act = vec![1.0f32; m * k];
        let block = q4_0_block_of_nines(2.0);
        let mut output = vec![0.0f32; m * n];
        quant_matmul_f32(&act, &block, &mut output, m, k, n, QuantFormat::Q4_0);
        // 32 weights of 2.0 against all-ones activations: dot = 64.
        assert!(
            (output[0] - 64.0).abs() < 0.5,
            "expected ~64.0, got {}",
            output[0]
        );
    }

    #[test]
    fn test_quant_matmul_q8_0_2x1() {
        let (m, k, n) = (2, 32, 1);
        // Row 0 of the activations is all ones, row 1 all halves.
        let mut act = vec![0.0f32; m * k];
        act[..k].fill(1.0);
        act[k..].fill(0.5);
        // Q8_0 block: f16 scale 0.5 + 32 int8 values of 4 → every weight 2.0.
        let mut block = [0u8; 34];
        block[0..2].copy_from_slice(&f16::from_f32(0.5).to_le_bytes());
        block[2..34].fill(4);
        let mut output = vec![0.0f32; m * n];
        quant_matmul_f32(&act, &block, &mut output, m, k, n, QuantFormat::Q8_0);
        // Row 0: 32 * 2.0 * 1.0 = 64; row 1: 32 * 2.0 * 0.5 = 32.
        assert!(
            (output[0] - 64.0).abs() < 0.5,
            "expected ~64.0, got {}",
            output[0]
        );
        assert!(
            (output[1] - 32.0).abs() < 0.5,
            "expected ~32.0, got {}",
            output[1]
        );
    }

    #[test]
    fn test_quant_matmul_multiple_output_cols() {
        let (m, k, n) = (1, 32, 2);
        let act = vec![1.0f32; m * k];
        // Two weight rows with different scales → two distinct output columns.
        let mut weight_bytes = Vec::new();
        weight_bytes.extend_from_slice(&q4_0_block_of_nines(1.0));
        weight_bytes.extend_from_slice(&q4_0_block_of_nines(3.0));
        let mut output = vec![0.0f32; m * n];
        quant_matmul_f32(&act, &weight_bytes, &mut output, m, k, n, QuantFormat::Q4_0);
        // Column 0: 32 * 1.0 = 32; column 1: 32 * 3.0 = 96.
        assert!(
            (output[0] - 32.0).abs() < 0.5,
            "expected ~32.0, got {}",
            output[0]
        );
        assert!(
            (output[1] - 96.0).abs() < 0.5,
            "expected ~96.0, got {}",
            output[1]
        );
    }
}