mod scalar;
#[cfg(target_arch = "x86_64")]
mod avx2;
#[cfg(target_arch = "x86_64")]
mod avx512;
use super::{SUPER_BLOCK_BYTES, SUPER_BLOCK_SIZE};
pub use scalar::{matmul_q4k_f32, matmul_q4k_f32_scalar};
#[allow(unused_imports)]
pub(crate) use scalar::compute_chunk_q4k_scalar;
/// Multiplies a Q4_K-quantized weight matrix of shape `[out_dim, in_dim]`
/// by `input`, returning the `out_dim`-length result vector.
///
/// Picks the fastest available kernel at runtime. On x86_64 the order is:
/// multi-threaded path for large matrices, then AVX-512, then AVX2; every
/// other case falls through to the portable scalar kernel.
#[inline]
pub fn matmul_q4k_f32_dispatch(
    q4k_data: &[u8],
    input: &[f32],
    out_dim: usize,
    in_dim: usize,
) -> Vec<f32> {
    debug_assert_eq!(input.len(), in_dim, "Q4K dispatch: input length mismatch");
    debug_assert!(
        q4k_data.len() >= crate::contracts::Q4_K.expected_bytes(out_dim, in_dim),
        "Q4K dispatch: buffer too small: {} bytes for [{}, {}] (need {})",
        q4k_data.len(),
        out_dim,
        in_dim,
        crate::contracts::Q4_K.expected_bytes(out_dim, in_dim),
    );
    #[cfg(target_arch = "x86_64")]
    {
        // Above this many multiply-accumulates the threading overhead pays off.
        const PARALLEL_THRESHOLD: usize = 8_000_000;
        if out_dim * in_dim >= PARALLEL_THRESHOLD {
            return matmul_q4k_f32_parallel(q4k_data, input, out_dim, in_dim);
        }
        let has_avx512 = is_x86_feature_detected!("avx512f")
            && is_x86_feature_detected!("avx512bw")
            && is_x86_feature_detected!("fma");
        if has_avx512 {
            // SAFETY: the required AVX-512 + FMA features were verified at runtime above.
            return unsafe { avx512::matmul_q4k_f32_avx512(q4k_data, input, out_dim, in_dim) };
        }
        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
            // SAFETY: AVX2 + FMA were verified at runtime above.
            return unsafe { avx2::matmul_q4k_f32_avx2(q4k_data, input, out_dim, in_dim) };
        }
    }
    scalar::matmul_q4k_f32(q4k_data, input, out_dim, in_dim)
}
/// Multi-threaded Q4_K matmul: splits the `out_dim` output rows into
/// contiguous chunks, one scoped thread per chunk, each running the best
/// SIMD kernel the CPU supports (AVX-512 → AVX2 → scalar).
#[cfg(target_arch = "x86_64")]
fn matmul_q4k_f32_parallel(
    q4k_data: &[u8],
    input: &[f32],
    out_dim: usize,
    in_dim: usize,
) -> Vec<f32> {
    use std::thread;
    // Cap at 12 threads; fall back to 4 if parallelism can't be queried.
    let num_threads = thread::available_parallelism().map(|p| p.get()).unwrap_or(4).min(12);
    // Ceiling division; `.max(1)` keeps `chunks_mut` from panicking with a
    // zero chunk size when `out_dim == 0`.
    let chunk_size = ((out_dim + num_threads - 1) / num_threads).max(1);
    let num_blocks_per_row = (in_dim + SUPER_BLOCK_SIZE - 1) / SUPER_BLOCK_SIZE;
    let row_bytes = num_blocks_per_row * SUPER_BLOCK_BYTES;
    // Zero-initialize instead of `with_capacity` + `set_len`: exposing
    // uninitialized f32s is undefined behavior, and zeroing is cheap.
    let mut output: Vec<f32> = vec![0.0; out_dim];
    // Detect features once, outside the worker threads.
    let has_avx512 = is_x86_feature_detected!("avx512f")
        && is_x86_feature_detected!("avx512bw")
        && is_x86_feature_detected!("fma");
    let has_avx2 = is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma");
    // Scoped threads let workers borrow `q4k_data`/`input` and disjoint
    // mutable chunks of `output` without `Arc` or `'static` bounds.
    thread::scope(|s| {
        let input_ref = input;
        let q4k_ref = q4k_data;
        for (chunk_idx, chunk) in output.chunks_mut(chunk_size).enumerate() {
            let start_row = chunk_idx * chunk_size;
            s.spawn(move || {
                if has_avx512 {
                    // SAFETY: AVX-512 + FMA availability was verified above.
                    unsafe {
                        avx512::compute_chunk_q4k_avx512(
                            q4k_ref,
                            input_ref,
                            chunk,
                            start_row,
                            out_dim,
                            in_dim,
                            num_blocks_per_row,
                            row_bytes,
                        );
                    }
                } else if has_avx2 {
                    // SAFETY: AVX2 + FMA availability was verified above.
                    unsafe {
                        avx2::compute_chunk_q4k_avx2(
                            q4k_ref,
                            input_ref,
                            chunk,
                            start_row,
                            out_dim,
                            in_dim,
                            num_blocks_per_row,
                            row_bytes,
                        );
                    }
                } else {
                    scalar::compute_chunk_q4k_scalar(
                        q4k_ref,
                        input_ref,
                        chunk,
                        start_row,
                        out_dim,
                        in_dim,
                        num_blocks_per_row,
                        row_bytes,
                    );
                }
            });
        }
    });
    output
}
/// Non-x86_64 stand-in for the threaded path: delegates straight to the
/// scalar kernel so both architectures expose the same internal API.
#[cfg(not(target_arch = "x86_64"))]
// The only call site lives inside a `#[cfg(target_arch = "x86_64")]` block in
// the dispatcher, so on non-x86 targets this fn is currently unreferenced;
// silence the dead_code warning rather than delete the arch-parity stub.
#[allow(dead_code)]
fn matmul_q4k_f32_parallel(
    q4k_data: &[u8],
    input: &[f32],
    out_dim: usize,
    in_dim: usize,
) -> Vec<f32> {
    scalar::matmul_q4k_f32(q4k_data, input, out_dim, in_dim)
}