//! Fused matrix-vector (matvec) kernels for K-quant weight formats (Q4_K,
//! Q5_K, Q6_K) operating on fp32 or Q8_K-quantized activations.
use super::bsum_precompute::{fused_q4k_q8k_dot_with_bsums_simd, precompute_q8k_bsums};
use super::format_trait::{Q5K, Q6K};
#[cfg(target_arch = "x86_64")]
use super::fused_k::fused_q4k_q8k_dot_4rows_avx512vnni;
use super::fused_k::{fused_q4k_dot_simd, fused_q4k_q8k_dot_simd};
use super::fused_q5k_q6k::{fused_q5k_dot_simd, fused_q6k_dot_simd};
use super::generic_matvec::{generic_parallel_matvec, generic_parallel_matvec_into};
use super::types::QK_K;
use crate::error::{RealizarError, Result};
use std::borrow::Cow;
const DEFAULT_OUTPUT_TILE_SIZE: usize = 64;
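/// Zero-pads `activations` up to `padded_len`, borrowing when no padding is
/// needed so the common aligned case allocates nothing.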
#[inline]
fn pad_activations(activations: &[f32], padded_len: usize) -> Cow<'_, [f32]> {
if activations.len() == padded_len {
Cow::Borrowed(activations)
} else {
let mut padded = vec![0.0f32; padded_len];
padded[..activations.len()].copy_from_slice(activations);
Cow::Owned(padded)
}
}
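/// Q4_K matvec that walks output rows in tiles of `tile_size` rows
/// (defaulting to [`DEFAULT_OUTPUT_TILE_SIZE`]), prefetching the start of the
/// next tile's weights on x86_64.
///
/// A minimal usage sketch; the dimensions are illustrative, and `weight_data`
/// must hold `out_dim` rows of `in_dim.div_ceil(QK_K) * 144` bytes:
///
/// ```ignore
/// let y = fused_q4k_tiled_matvec(&weight_data, &activations, 4096, 4096, None)?;
/// assert_eq!(y.len(), 4096);
/// ```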
#[allow(clippy::similar_names)]
pub fn fused_q4k_tiled_matvec(
weight_data: &[u8],
activations: &[f32],
in_dim: usize,
out_dim: usize,
tile_size: Option<usize>,
) -> Result<Vec<f32>> {
let tile_size = tile_size.unwrap_or(DEFAULT_OUTPUT_TILE_SIZE);
let super_blocks_per_row = in_dim.div_ceil(QK_K);
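    // Each Q4_K super-block encodes QK_K values in 144 bytes:
    // d (f16) + dmin (f16) + 12 scale bytes + 128 packed 4-bit quants.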
let bytes_per_row = super_blocks_per_row * 144;
let expected_weight_bytes = out_dim * bytes_per_row;
if weight_data.len() < expected_weight_bytes {
return Err(RealizarError::InvalidShape {
reason: format!(
"Q4_K weight data too small: need {} bytes for {}x{}, have {}",
expected_weight_bytes,
out_dim,
in_dim,
weight_data.len()
),
});
}
if activations.len() != in_dim {
return Err(RealizarError::InvalidShape {
reason: format!(
"Activation length {} doesn't match in_dim {}",
activations.len(),
in_dim
),
});
}
let padded_in_dim = super_blocks_per_row * QK_K;
let acts = pad_activations(activations, padded_in_dim);
let mut output = vec![0.0f32; out_dim];
let num_tiles = out_dim.div_ceil(tile_size);
for tile_idx in 0..num_tiles {
let tile_start = tile_idx * tile_size;
let tile_end = (tile_start + tile_size).min(out_dim);
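        // On x86_64, hint the first cache line of the next tile's weight rows
        // into cache while the current tile is being computed.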
#[cfg(target_arch = "x86_64")]
if tile_idx + 1 < num_tiles {
let next_tile_start = (tile_idx + 1) * tile_size;
let next_row_start = next_tile_start * bytes_per_row;
if next_row_start < weight_data.len() {
unsafe {
use std::arch::x86_64::_mm_prefetch;
use std::arch::x86_64::_MM_HINT_T0;
let ptr = weight_data.as_ptr().add(next_row_start);
_mm_prefetch(ptr.cast::<i8>(), _MM_HINT_T0);
}
}
}
for (idx, out_slot) in output[tile_start..tile_end].iter_mut().enumerate() {
let o = tile_start + idx;
let row_start = o * bytes_per_row;
let row_end = row_start + bytes_per_row;
let row_data = &weight_data[row_start..row_end];
*out_slot = fused_q4k_dot_simd(row_data, &acts)?;
}
}
Ok(output)
}
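/// Allocating convenience wrapper around [`fused_q4k_parallel_matvec_into`].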
#[allow(clippy::similar_names)]
pub fn fused_q4k_parallel_matvec(
weight_data: &[u8],
activations: &[f32],
in_dim: usize,
out_dim: usize,
) -> Result<Vec<f32>> {
let mut output = vec![0.0f32; out_dim];
fused_q4k_parallel_matvec_into(weight_data, activations, in_dim, out_dim, &mut output)?;
Ok(output)
}
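/// Parallel Q4_K matvec writing into a caller-provided `output` slice.
///
/// Activations are first quantized to Q8_K (using stack buffers when the
/// padded input dimension fits), and the fused Q4_K x Q8_K kernel handles the
/// rows. Setting the `DIRECT_FP32_GEMV=1` environment variable bypasses
/// quantization and dots each row directly against the fp32 activations.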
#[allow(clippy::similar_names)]
pub fn fused_q4k_parallel_matvec_into(
weight_data: &[u8],
activations: &[f32],
in_dim: usize,
out_dim: usize,
output: &mut [f32],
) -> Result<()> {
if activations.len() != in_dim {
return Err(RealizarError::InvalidShape {
reason: format!(
"Q4K activation length {} doesn't match in_dim {}",
activations.len(),
in_dim
),
});
}
let super_blocks_per_row = in_dim.div_ceil(QK_K);
let padded_in_dim = super_blocks_per_row * QK_K;
let acts = pad_activations(activations, padded_in_dim);
let use_direct_fp32 = std::env::var("DIRECT_FP32_GEMV").as_deref() == Ok("1");
    if use_direct_fp32 {
        use rayon::prelude::*;
        const SB_BYTES: usize = 144;
        let bytes_per_row = super_blocks_per_row * SB_BYTES;
        if weight_data.len() < out_dim * bytes_per_row {
            return Err(RealizarError::InvalidShape {
                reason: format!(
                    "Q4_K weight data too small: need {} bytes for {}x{}, have {}",
                    out_dim * bytes_per_row,
                    out_dim,
                    in_dim,
                    weight_data.len()
                ),
            });
        }
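        // 64-row chunks amortize rayon scheduling overhead; a per-row kernel
        // error is mapped to 0.0 rather than aborting the parallel loop.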
output[..out_dim]
.par_chunks_mut(64)
.enumerate()
.for_each(|(chunk_idx, chunk)| {
let row_base = chunk_idx * 64;
for (i, out) in chunk.iter_mut().enumerate() {
let row = row_base + i;
let row_start = row * bytes_per_row;
let row_data = &weight_data[row_start..row_start + bytes_per_row];
*out = super::fused_k::fused_q4k_dot_simd(row_data, &acts).unwrap_or(0.0);
}
});
return Ok(());
}
let num_superblocks = padded_in_dim / QK_K;
    // Q8_K scratch fits on the stack up to 35 super-blocks (35 * 256 = 8960 values).
    const MAX_STACK_DIM: usize = 8960;
    const MAX_STACK_SB: usize = (MAX_STACK_DIM + QK_K - 1) / QK_K;
if padded_in_dim <= MAX_STACK_DIM {
let mut scales = [0.0f32; MAX_STACK_SB];
let mut quants = [0i8; MAX_STACK_DIM];
super::quantize_activations_q8k_into(
&acts,
&mut scales[..num_superblocks],
&mut quants[..padded_in_dim],
)?;
fused_q4k_q8k_parallel_matvec_into(
weight_data,
&scales[..num_superblocks],
&quants[..padded_in_dim],
in_dim,
out_dim,
output,
)
} else {
let mut scales = vec![0.0f32; num_superblocks];
let mut quants = vec![0i8; padded_in_dim];
super::quantize_activations_q8k_into(&acts, &mut scales, &mut quants)?;
fused_q4k_q8k_parallel_matvec_into(weight_data, &scales, &quants, in_dim, out_dim, output)
}
}
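/// Q4_K matvec against activations already quantized to Q8_K (scales, quants,
/// and per-group byte sums), as produced by [`quantize_for_q4k_matvec`]. On
/// x86_64 with AVX2+FMA the bsum-aware raw kernel runs per row; other targets
/// fall back to [`fused_q4k_q8k_parallel_matvec_into`], which does not read
/// `q8k_bsums`.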
#[allow(dead_code)]
pub fn fused_q4k_preq8k_matvec_into(
weight_data: &[u8],
q8k_scales: &[f32],
q8k_quants: &[i8],
q8k_bsums: &[i16],
in_dim: usize,
out_dim: usize,
output: &mut [f32],
) -> Result<()> {
let super_blocks_per_row = in_dim.div_ceil(QK_K);
let bytes_per_row = super_blocks_per_row * 144;
if weight_data.len() < out_dim * bytes_per_row {
        return Err(RealizarError::InvalidShape {
            reason: format!(
                "Q4_K weight data too small: need {} bytes for {}x{}, have {}",
                out_dim * bytes_per_row,
                out_dim,
                in_dim,
                weight_data.len()
            ),
        });
}
#[cfg(target_arch = "x86_64")]
if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
use rayon::prelude::*;
let bpr = bytes_per_row;
let nsb = super_blocks_per_row;
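        // Raw pointers are not `Send`, so the base addresses cross into the
        // rayon closure as `usize` and are re-cast inside. Safety: the slices
        // outlive the parallel loop and each chunk writes disjoint output rows.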
let w_addr = weight_data.as_ptr() as usize;
let sc_addr = q8k_scales.as_ptr() as usize;
let qq_addr = q8k_quants.as_ptr() as usize;
let bs_addr = q8k_bsums.as_ptr() as usize;
output[..out_dim]
.par_chunks_mut(64)
.enumerate()
.for_each(|(ci, chunk)| {
let row_start = ci * 64;
unsafe {
let w = w_addr as *const u8;
let sc = sc_addr as *const f32;
let qq = qq_addr as *const i8;
let bs = bs_addr as *const i16;
for (i, out) in chunk.iter_mut().enumerate() {
let row = row_start + i;
*out = super::fused_k::ggml_style_q4k_q8k_dot_avx2_raw(
w.add(row * bpr),
sc,
qq,
bs,
nsb,
);
}
}
});
return Ok(());
}
fused_q4k_q8k_parallel_matvec_into(weight_data, q8k_scales, q8k_quants, in_dim, out_dim, output)
}
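/// Quantizes fp32 activations to Q8_K, returning `(scales, quants, bsums)`
/// ready for [`fused_q4k_preq8k_matvec_into`]. On non-x86_64 targets the byte
/// sums are zero-filled, which is harmless because the fallback kernel does
/// not read them.
///
/// A minimal usage sketch (`weights`, `activations`, and the dimensions are
/// illustrative):
///
/// ```ignore
/// let (scales, quants, bsums) = quantize_for_q4k_matvec(&activations, in_dim)?;
/// fused_q4k_preq8k_matvec_into(&weights, &scales, &quants, &bsums, in_dim, out_dim, &mut out)?;
/// ```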
#[allow(dead_code)]
pub fn quantize_for_q4k_matvec(
activations: &[f32],
in_dim: usize,
) -> Result<(Vec<f32>, Vec<i8>, Vec<i16>)> {
let super_blocks_per_row = in_dim.div_ceil(QK_K);
let padded_in_dim = super_blocks_per_row * QK_K;
let num_superblocks = padded_in_dim / QK_K;
let acts = pad_activations(activations, padded_in_dim);
let mut scales = vec![0.0f32; num_superblocks];
let mut quants = vec![0i8; padded_in_dim];
super::quantize_activations_q8k_into(&acts, &mut scales, &mut quants)?;
#[cfg(target_arch = "x86_64")]
let bsums = unsafe { super::fused_k::precompute_q8k_bsums_i16(&quants, num_superblocks) };
#[cfg(not(target_arch = "x86_64"))]
let bsums = vec![0i16; num_superblocks * 16];
Ok((scales, quants, bsums))
}
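/// Reusable scratch buffers for Q8_K-quantized activations, grown on demand
/// so repeated matvec calls avoid per-call allocation.
///
/// A minimal usage sketch; the `thread_local` wiring here is illustrative,
/// not this crate's actual plumbing:
///
/// ```ignore
/// use std::cell::RefCell;
/// thread_local! {
///     static Q8K_WS: RefCell<CpuQ8kWorkspace> = RefCell::new(CpuQ8kWorkspace::empty());
/// }
/// Q8K_WS.with(|ws| {
///     let mut ws = ws.borrow_mut();
///     ws.ensure_capacity(padded_in_dim);
///     let (scales, quants) = ws.buffers_mut(padded_in_dim / QK_K, padded_in_dim);
///     // ... quantize activations into `scales` / `quants` ...
/// });
/// ```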
struct CpuQ8kWorkspace {
scales: Vec<f32>,
quants: Vec<i8>,
}
impl CpuQ8kWorkspace {
const fn empty() -> Self {
Self {
scales: Vec::new(),
quants: Vec::new(),
}
}
fn ensure_capacity(&mut self, padded_in_dim: usize) {
let num_superblocks = padded_in_dim / QK_K;
if self.scales.len() < num_superblocks {
self.scales.resize(num_superblocks, 0.0);
}
if self.quants.len() < padded_in_dim {
self.quants.resize(padded_in_dim, 0);
}
}
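    /// Returns the leading sub-slices of each buffer. Callers must invoke
    /// [`ensure_capacity`](Self::ensure_capacity) first; the slicing below
    /// panics otherwise.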
fn buffers_mut(&mut self, num_sb: usize, padded_dim: usize) -> (&mut [f32], &mut [i8]) {
(&mut self.scales[..num_sb], &mut self.quants[..padded_dim])
}
}
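/// Parallel Q5_K matvec; row iteration is delegated to the generic driver,
/// parameterized with the fused Q5_K dot kernel.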
#[allow(clippy::similar_names)]
pub fn fused_q5k_parallel_matvec(
weight_data: &[u8],
activations: &[f32],
in_dim: usize,
out_dim: usize,
) -> Result<Vec<f32>> {
generic_parallel_matvec::<Q5K>(
weight_data,
activations,
in_dim,
out_dim,
fused_q5k_dot_simd,
)
}
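// Sibling source files are spliced in via `include!` so they can share this
// module's private helpers and imports (which is why some `use` items above
// appear unused here).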
include!("q5k_q6k_matvec.rs");
include!("parallel_k_fused_q4k.rs");