use super::{Backend, ComputeOp};
use crate::error::TruenoError;
/// One Q5_K super-block: 256 values stored as 5-bit quants with per-sub-block
/// 6-bit scales, a super-block scale `d`, and an additive offset `dmin`.
#[derive(Debug, Clone)]
pub struct BlockQ5K {
    /// Super-block scale factor applied to every sub-block scale.
    pub d: f32,
    /// Additive offset applied to every dequantized value.
    pub dmin: f32,
    /// Packed sub-block scales. NOTE(review): only the low 6 bits of the
    /// first 8 bytes are consumed by `dequantize`; bytes 8..12 are unused
    /// here — confirm against the matching quantizer.
    pub scales: [u8; 12],
    /// High (5th) bit of each quant: bit `sub` of `qh[j]` belongs to element
    /// `j` of sub-block `sub`.
    pub qh: [u8; 32],
    /// Low 4 bits of each quant, two values per byte.
    pub qs: [u8; 128],
}

impl BlockQ5K {
    /// Number of f32 values encoded by one super-block.
    pub const BLOCK_SIZE: usize = 256;

    /// Expands the 256 packed 5-bit quants into `output[..256]`.
    ///
    /// Each value decodes as `d * scale[sub] * (q5 - 16) + dmin`, where the
    /// sub-block scale is `(scales[sub] & 0x3F) - 32` and `q5` combines a
    /// nibble from `qs` with one bit from `qh`.
    pub fn dequantize(&self, output: &mut [f32]) {
        debug_assert!(output.len() >= Self::BLOCK_SIZE);
        for sub in 0..8 {
            // 6-bit scale, re-biased to the signed range -32..=31.
            let scale = f32::from((self.scales[sub] & 0x3F) as i8 - 32);
            let combined = self.d * scale;
            let base = sub * 32;
            for j in 0..32 {
                let packed = self.qs[base / 2 + j / 2];
                // Even elements take the low nibble, odd the high nibble.
                let low4 = if j & 1 == 0 { packed & 0x0F } else { packed >> 4 };
                let bit5 = (self.qh[j] >> sub) & 1;
                let q5 = low4 | (bit5 << 4);
                output[base + j] = combined * (f32::from(q5) - 16.0) + self.dmin;
            }
        }
    }
}
/// One Q6_K super-block: 256 values stored as 6-bit quants with one signed
/// scale per 16-element sub-block and a super-block scale `d`.
#[derive(Debug, Clone)]
pub struct BlockQ6K {
    /// Low 4 bits of each quant, two values per byte.
    pub ql: [u8; 128],
    /// High 2 bits of each quant, four values per byte.
    pub qh: [u8; 64],
    /// One signed scale per 16-element sub-block.
    pub scales: [i8; 16],
    /// Super-block scale factor.
    pub d: f32,
}

impl BlockQ6K {
    /// Number of f32 values encoded by one super-block.
    pub const BLOCK_SIZE: usize = 256;

    /// Expands the 256 packed 6-bit quants into `output[..256]`.
    ///
    /// Each value decodes as `d * scales[sub] * (q6 - 32)`, where `q6`
    /// combines a nibble from `ql` with a 2-bit field from `qh`.
    pub fn dequantize(&self, output: &mut [f32]) {
        debug_assert!(output.len() >= Self::BLOCK_SIZE);
        for (sub, &sc) in self.scales.iter().enumerate() {
            let combined = self.d * f32::from(sc);
            let base = sub * 16;
            for j in 0..16 {
                let lo_byte = self.ql[base / 2 + j / 2];
                // Even elements take the low nibble, odd the high nibble.
                let lo = if j & 1 == 0 { lo_byte & 0x0F } else { lo_byte >> 4 };
                // Four 2-bit high fields per qh byte, LSB-first.
                let hi = (self.qh[base / 4 + j / 4] >> ((j % 4) * 2)) & 0x03;
                let q6 = lo | (hi << 4);
                output[base + j] = combined * (f32::from(q6) - 32.0);
            }
        }
    }
}
/// Dot product of a Q5_K-quantized vector against a dense f32 vector.
#[derive(Debug, Clone)]
pub struct DotQ5KOp {
    /// Number of Q5_K super-blocks covered by this op (`n_elements / 256`).
    pub n_blocks: usize,
}

impl DotQ5KOp {
    /// Creates a dot-product op sized for `n_elements` f32 values.
    ///
    /// Any trailing partial block (`n_elements % 256`) is not counted.
    #[must_use]
    pub fn new(n_elements: usize) -> Self {
        Self { n_blocks: n_elements / BlockQ5K::BLOCK_SIZE }
    }

    /// Dot product of one dequantized Q5_K block with `x[..256]`, using
    /// AVX2 FMA over 8-lane chunks.
    ///
    /// # Safety
    /// The caller must guarantee the CPU supports AVX2 and FMA (e.g. via
    /// `is_x86_feature_detected!`). The length of `x` is validated here.
    ///
    /// # Panics
    /// Panics if `x` holds fewer than [`BlockQ5K::BLOCK_SIZE`] elements.
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "avx2", enable = "fma")]
    unsafe fn avx2_dot_block(block: &BlockQ5K, x: &[f32]) -> f32 {
        // BUG FIX: the unchecked `loadu` calls below previously read past the
        // end of `x` when it held fewer than 256 elements — undefined
        // behavior. Fail loudly instead, mirroring the scalar path's panic.
        assert!(
            x.len() >= BlockQ5K::BLOCK_SIZE,
            "dot_q5k: x slice shorter than one block"
        );
        unsafe {
            use std::arch::x86_64::*;
            let mut acc = _mm256_setzero_ps();
            let mut dequant = [0.0f32; BlockQ5K::BLOCK_SIZE];
            block.dequantize(&mut dequant);
            // 256 is a multiple of 8, so this loop covers every element.
            let mut i = 0;
            while i + 8 <= BlockQ5K::BLOCK_SIZE {
                // SAFETY: i + 8 <= 256 and both slices hold >= 256 f32s.
                let vd = _mm256_loadu_ps(dequant.as_ptr().add(i));
                let vx = _mm256_loadu_ps(x.as_ptr().add(i));
                acc = _mm256_fmadd_ps(vd, vx, acc);
                i += 8;
            }
            // Horizontal reduction of the 8 accumulator lanes to one f32.
            let high = _mm256_extractf128_ps(acc, 1);
            let low = _mm256_castps256_ps128(acc);
            let sum128 = _mm_add_ps(high, low);
            let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
            let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
            _mm_cvtss_f32(sum32)
        }
    }
}
impl ComputeOp for DotQ5KOp {
    type Input = (Vec<BlockQ5K>, Vec<f32>);
    type Output = f32;

    fn name(&self) -> &'static str {
        "dot_q5k"
    }

    /// Computes `sum_i dequantize(blocks[i]) . x[i*256 .. (i+1)*256]`.
    ///
    /// Returns `Ok(0.0)` when either input is empty (preserved legacy
    /// behavior).
    ///
    /// # Panics
    /// Panics if `x` holds fewer than `blocks.len() * 256` elements.
    fn execute(&self, input: Self::Input, backend: Backend) -> Result<Self::Output, TruenoError> {
        let (blocks, x) = input;
        if blocks.is_empty() || x.is_empty() {
            return Ok(0.0);
        }
        // BUG FIX: previously a short `x` reached the AVX2 kernel unchecked,
        // causing out-of-bounds SIMD loads (undefined behavior). Validate the
        // length once up front so both paths see full 256-element chunks.
        assert!(
            x.len() >= blocks.len() * BlockQ5K::BLOCK_SIZE,
            "dot_q5k: x has {} elements, need {}",
            x.len(),
            blocks.len() * BlockQ5K::BLOCK_SIZE
        );
        let mut sum = 0.0f32;
        #[cfg(target_arch = "x86_64")]
        {
            if matches!(backend, Backend::Avx2 | Backend::Auto) && is_x86_feature_detected!("avx2")
            {
                for (block, xs) in blocks.iter().zip(x.chunks_exact(BlockQ5K::BLOCK_SIZE)) {
                    // SAFETY: AVX2 support was just verified at runtime, and
                    // `chunks_exact` guarantees `xs` holds exactly 256 values.
                    sum += unsafe { Self::avx2_dot_block(block, xs) };
                }
                return Ok(sum);
            }
        }
        // Scalar fallback: dequantize into one reusable buffer per block.
        let mut dequant = [0.0f32; BlockQ5K::BLOCK_SIZE];
        for (block, xs) in blocks.iter().zip(x.chunks_exact(BlockQ5K::BLOCK_SIZE)) {
            block.dequantize(&mut dequant);
            sum += dequant.iter().zip(xs).map(|(d, v)| d * v).sum::<f32>();
        }
        Ok(sum)
    }

    /// Nominal element count processed by this op (used for throughput
    /// accounting), independent of the actual input.
    fn tokens(&self, _input: &Self::Input) -> usize {
        self.n_blocks * BlockQ5K::BLOCK_SIZE
    }
}
/// Dot product of a Q6_K-quantized vector against a dense f32 vector.
#[derive(Debug, Clone)]
pub struct DotQ6KOp {
    /// Number of Q6_K super-blocks covered by this op (`n_elements / 256`).
    pub n_blocks: usize,
}

impl DotQ6KOp {
    /// Creates a dot-product op sized for `n_elements` f32 values.
    ///
    /// Any trailing partial block (`n_elements % 256`) is not counted.
    #[must_use]
    pub fn new(n_elements: usize) -> Self {
        Self { n_blocks: n_elements / BlockQ6K::BLOCK_SIZE }
    }

    /// Dot product of one dequantized Q6_K block with `x[..256]`, using
    /// AVX2 FMA over 8-lane chunks.
    ///
    /// # Safety
    /// The caller must guarantee the CPU supports AVX2 and FMA (e.g. via
    /// `is_x86_feature_detected!`). The length of `x` is validated here.
    ///
    /// # Panics
    /// Panics if `x` holds fewer than [`BlockQ6K::BLOCK_SIZE`] elements.
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "avx2", enable = "fma")]
    unsafe fn avx2_dot_block(block: &BlockQ6K, x: &[f32]) -> f32 {
        // BUG FIX: the unchecked `loadu` calls below previously read past the
        // end of `x` when it held fewer than 256 elements — undefined
        // behavior. Fail loudly instead, mirroring the scalar path's panic.
        assert!(
            x.len() >= BlockQ6K::BLOCK_SIZE,
            "dot_q6k: x slice shorter than one block"
        );
        unsafe {
            use std::arch::x86_64::*;
            let mut acc = _mm256_setzero_ps();
            let mut dequant = [0.0f32; BlockQ6K::BLOCK_SIZE];
            block.dequantize(&mut dequant);
            // 256 is a multiple of 8, so this loop covers every element.
            let mut i = 0;
            while i + 8 <= BlockQ6K::BLOCK_SIZE {
                // SAFETY: i + 8 <= 256 and both slices hold >= 256 f32s.
                let vd = _mm256_loadu_ps(dequant.as_ptr().add(i));
                let vx = _mm256_loadu_ps(x.as_ptr().add(i));
                acc = _mm256_fmadd_ps(vd, vx, acc);
                i += 8;
            }
            // Horizontal reduction of the 8 accumulator lanes to one f32.
            let high = _mm256_extractf128_ps(acc, 1);
            let low = _mm256_castps256_ps128(acc);
            let sum128 = _mm_add_ps(high, low);
            let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
            let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
            _mm_cvtss_f32(sum32)
        }
    }
}
impl ComputeOp for DotQ6KOp {
    type Input = (Vec<BlockQ6K>, Vec<f32>);
    type Output = f32;

    fn name(&self) -> &'static str {
        "dot_q6k"
    }

    /// Computes `sum_i dequantize(blocks[i]) . x[i*256 .. (i+1)*256]`.
    ///
    /// Returns `Ok(0.0)` when either input is empty (preserved legacy
    /// behavior).
    ///
    /// # Panics
    /// Panics if `x` holds fewer than `blocks.len() * 256` elements.
    fn execute(&self, input: Self::Input, backend: Backend) -> Result<Self::Output, TruenoError> {
        let (blocks, x) = input;
        if blocks.is_empty() || x.is_empty() {
            return Ok(0.0);
        }
        // BUG FIX: previously a short `x` reached the AVX2 kernel unchecked,
        // causing out-of-bounds SIMD loads (undefined behavior). Validate the
        // length once up front so both paths see full 256-element chunks.
        assert!(
            x.len() >= blocks.len() * BlockQ6K::BLOCK_SIZE,
            "dot_q6k: x has {} elements, need {}",
            x.len(),
            blocks.len() * BlockQ6K::BLOCK_SIZE
        );
        let mut sum = 0.0f32;
        #[cfg(target_arch = "x86_64")]
        {
            if matches!(backend, Backend::Avx2 | Backend::Auto) && is_x86_feature_detected!("avx2")
            {
                for (block, xs) in blocks.iter().zip(x.chunks_exact(BlockQ6K::BLOCK_SIZE)) {
                    // SAFETY: AVX2 support was just verified at runtime, and
                    // `chunks_exact` guarantees `xs` holds exactly 256 values.
                    sum += unsafe { Self::avx2_dot_block(block, xs) };
                }
                return Ok(sum);
            }
        }
        // Scalar fallback: dequantize into one reusable buffer per block.
        let mut dequant = [0.0f32; BlockQ6K::BLOCK_SIZE];
        for (block, xs) in blocks.iter().zip(x.chunks_exact(BlockQ6K::BLOCK_SIZE)) {
            block.dequantize(&mut dequant);
            sum += dequant.iter().zip(xs).map(|(d, v)| d * v).sum::<f32>();
        }
        Ok(sum)
    }

    /// Nominal element count processed by this op (used for throughput
    /// accounting), independent of the actual input.
    fn tokens(&self, _input: &Self::Input) -> usize {
        self.n_blocks * BlockQ6K::BLOCK_SIZE
    }
}
#[cfg(test)]
pub mod nf4;
mod tests;