#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;
#[cfg(target_arch = "x86_64")]
use oxibonsai_core::tensor::{BlockQ1_0G128, QK1_0_G128};
#[cfg(target_arch = "x86_64")]
use crate::error::{KernelError, KernelResult};
/// Expands 1-bit, 128-wide quantized blocks into `f32` values using AVX2.
///
/// Block `i` fills `output[i * QK1_0_G128..(i + 1) * QK1_0_G128]`. Within a
/// block, element `8 * c + j` is `+d` when bit `j` of `qs[c]` is set and `-d`
/// when it is clear, with `d` the block's scale. Output beyond the last block
/// is left untouched.
///
/// # Errors
///
/// Returns [`KernelError::BufferTooSmall`] when `output` holds fewer than
/// `blocks.len() * QK1_0_G128` values; nothing is written in that case.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX2 and FMA.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
pub unsafe fn dequant_1bit_g128_avx2(
    blocks: &[BlockQ1_0G128],
    output: &mut [f32],
) -> KernelResult<()> {
    let needed = blocks.len() * QK1_0_G128;
    let available = output.len();
    if available < needed {
        return Err(KernelError::BufferTooSmall { needed, available });
    }
    for (block, dst) in blocks.iter().zip(output.chunks_exact_mut(QK1_0_G128)) {
        let scale = _mm256_set1_ps(block.d.to_f32());
        let dst = dst.as_mut_ptr();
        // Each sign byte covers eight consecutive outputs of this block.
        for (group, &packed) in block.qs[..16].iter().enumerate() {
            let values = _mm256_mul_ps(scale, bits_to_signs_avx2(packed));
            _mm256_storeu_ps(dst.add(group * 8), values);
        }
    }
    Ok(())
}
/// Matrix-vector product against a 1-bit quantized weight matrix (AVX2 + FMA).
///
/// `blocks` holds `n_rows` weight rows of `k / QK1_0_G128` consecutive blocks.
/// Each weight is `+d` or `-d` (the owning block's scale, sign taken from its
/// bit), and `output[row]` receives the dot product of that row with
/// `input[..k]`. Only the first `n_rows` entries of `output` are written.
///
/// # Errors
///
/// - [`KernelError::NotBlockAligned`] if `k` is not a multiple of `QK1_0_G128`.
/// - [`KernelError::DimensionMismatch`] if `input` has fewer than `k` values.
/// - [`KernelError::BufferTooSmall`] if `output` has fewer than `n_rows` slots,
///   or `blocks` holds fewer than `n_rows * (k / QK1_0_G128)` blocks.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX2 and FMA.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
pub unsafe fn gemv_1bit_g128_avx2(
    blocks: &[BlockQ1_0G128],
    input: &[f32],
    output: &mut [f32],
    n_rows: usize,
    k: usize,
) -> KernelResult<()> {
    if k % QK1_0_G128 != 0 {
        return Err(KernelError::NotBlockAligned {
            count: k,
            block_size: QK1_0_G128,
        });
    }
    if input.len() < k {
        return Err(KernelError::DimensionMismatch {
            expected: k,
            got: input.len(),
        });
    }
    if output.len() < n_rows {
        return Err(KernelError::BufferTooSmall {
            needed: n_rows,
            available: output.len(),
        });
    }
    let blocks_per_row = k / QK1_0_G128;
    let needed_blocks = n_rows * blocks_per_row;
    if blocks.len() < needed_blocks {
        return Err(KernelError::BufferTooSmall {
            needed: needed_blocks,
            available: blocks.len(),
        });
    }

    // Every load below stays inside input[..k], which the checks above guarantee.
    let x = input.as_ptr();
    for (row, slot) in output[..n_rows].iter_mut().enumerate() {
        let start = row * blocks_per_row;
        let mut sum = _mm256_setzero_ps();
        for (block_idx, block) in blocks[start..start + blocks_per_row].iter().enumerate() {
            let scale = _mm256_set1_ps(block.d.to_f32());
            let x_block = x.add(block_idx * QK1_0_G128);
            for (group, &packed) in block.qs[..16].iter().enumerate() {
                let xs = _mm256_loadu_ps(x_block.add(group * 8));
                let signed = _mm256_mul_ps(bits_to_signs_avx2(packed), xs);
                sum = _mm256_fmadd_ps(scale, signed, sum);
            }
        }
        *slot = hsum_avx2(sum);
    }
    Ok(())
}
/// Batched matrix product against a 1-bit quantized weight matrix (AVX2 + FMA).
///
/// `input` is read as `m` vectors of length `k` stored back to back, and
/// `blocks` as `n_rows` weight rows of `k / QK1_0_G128` consecutive blocks.
/// Each weight is `+d` or `-d` (the owning block's scale, sign from its bit).
/// The kernel writes `output[mi * n_rows + ni]` = dot(weight row `ni`,
/// `input[mi * k..(mi + 1) * k]`).
///
/// # Errors
///
/// - [`KernelError::NotBlockAligned`] if `k` is not a multiple of `QK1_0_G128`.
/// - [`KernelError::DimensionMismatch`] if `input` has fewer than `m * k` values.
/// - [`KernelError::BufferTooSmall`] if `output` has fewer than `m * n_rows`
///   slots, or `blocks` holds fewer than `n_rows * (k / QK1_0_G128)` blocks.
///
/// A dimension product that would overflow `usize` is reported as requiring
/// `usize::MAX` elements instead of being allowed to wrap around.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX2 and FMA. The vector
/// loads from `input` are unchecked; their memory safety relies on the length
/// checks performed here.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
pub unsafe fn gemm_1bit_g128_avx2(
    blocks: &[BlockQ1_0G128],
    input: &[f32],
    output: &mut [f32],
    m: usize,
    n_rows: usize,
    k: usize,
) -> KernelResult<()> {
    if k % QK1_0_G128 != 0 {
        return Err(KernelError::NotBlockAligned {
            count: k,
            block_size: QK1_0_G128,
        });
    }
    // Saturate instead of wrapping: a wrapped `m * k` could slip past this
    // check and let the unchecked `_mm256_loadu_ps` reads below run off the
    // end of `input`.
    let input_needed = m.saturating_mul(k);
    if input.len() < input_needed {
        return Err(KernelError::DimensionMismatch {
            expected: input_needed,
            got: input.len(),
        });
    }
    let output_needed = m.saturating_mul(n_rows);
    if output.len() < output_needed {
        return Err(KernelError::BufferTooSmall {
            needed: output_needed,
            available: output.len(),
        });
    }
    let blocks_per_row = k / QK1_0_G128;
    let expected_blocks = n_rows.saturating_mul(blocks_per_row);
    if blocks.len() < expected_blocks {
        return Err(KernelError::BufferTooSmall {
            needed: expected_blocks,
            available: blocks.len(),
        });
    }
    for mi in 0..m {
        // The checks above guarantee (mi + 1) * k <= input.len(), so every
        // load at offset < k from this row pointer is in bounds.
        let input_row = &input[mi * k..];
        for ni in 0..n_rows {
            let row_blocks = &blocks[ni * blocks_per_row..(ni + 1) * blocks_per_row];
            let mut acc = _mm256_setzero_ps();
            for (bi, block) in row_blocks.iter().enumerate() {
                let d = block.d.to_f32();
                let scale = _mm256_set1_ps(d);
                let input_base = bi * QK1_0_G128;
                for chunk in 0..16 {
                    let bits = block.qs[chunk];
                    let inp_offset = input_base + chunk * 8;
                    let inp = _mm256_loadu_ps(input_row.as_ptr().add(inp_offset));
                    let signs = bits_to_signs_avx2(bits);
                    let signed_input = _mm256_mul_ps(signs, inp);
                    acc = _mm256_fmadd_ps(scale, signed_input, acc);
                }
            }
            output[mi * n_rows + ni] = hsum_avx2(acc);
        }
    }
    Ok(())
}
/// Expands the eight bits of `bits` into eight `f32` lanes: lane `i` is `1.0`
/// when bit `i` is set and `-1.0` when it is clear (bit 0 maps to lane 0).
///
/// # Safety
///
/// Requires AVX2 support on the running CPU.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn bits_to_signs_avx2(bits: u8) -> __m256 {
    // Broadcast the byte to every lane, then keep only bit `i` in lane `i`.
    let lane_bits = _mm256_setr_epi32(1, 2, 4, 8, 16, 32, 64, 128);
    let picked = _mm256_and_si256(_mm256_set1_epi32(i32::from(bits)), lane_bits);
    // All-ones mask in lanes whose bit was clear; those lanes select -1.0.
    let clear = _mm256_cmpeq_epi32(picked, _mm256_setzero_si256());
    _mm256_blendv_ps(
        _mm256_set1_ps(1.0),
        _mm256_set1_ps(-1.0),
        _mm256_castsi256_ps(clear),
    )
}
/// Returns the sum of all eight `f32` lanes of `v`.
///
/// The upper 128-bit half is folded onto the lower half, adjacent pairs are
/// then added, and finally the two remaining partial sums are combined.
///
/// # Safety
///
/// Requires AVX2 support on the running CPU.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn hsum_avx2(v: __m256) -> f32 {
    // quad = [v0+v4, v1+v5, v2+v6, v3+v7]
    let quad = _mm_add_ps(_mm256_castps256_ps128(v), _mm256_extractf128_ps(v, 1));
    // Lanes 0 and 2 now hold (q0+q1) and (q2+q3).
    let pairs = _mm_add_ps(quad, _mm_movehdup_ps(quad));
    // Move lane 2 down to lane 0 and add: lane 0 = q0+q1+q2+q3.
    let total = _mm_add_ss(pairs, _mm_movehl_ps(pairs, pairs));
    _mm_cvtss_f32(total)
}
/// Prefetching, dual-accumulator variant of the AVX2 1-bit matrix-vector product.
///
/// Same contract as `gemv_1bit_g128_avx2`: `output[row]` receives the dot
/// product of weight row `row` (`k / QK1_0_G128` blocks of `±d` weights) with
/// `input[..k]`. Blocks are consumed in pairs: even-indexed blocks feed one
/// accumulator, odd-indexed blocks a second one, and a trailing odd block
/// feeds the first. The two accumulators are added before the horizontal
/// reduction, so results can differ from the single-accumulator kernel by
/// floating-point rounding.
///
/// Software prefetches are issued for the first block of the next row, and,
/// inside a row, for the block following the current pair.
///
/// # Errors
///
/// - [`KernelError::NotBlockAligned`] if `k` is not a multiple of `QK1_0_G128`.
/// - [`KernelError::DimensionMismatch`] if `input` has fewer than `k` values.
/// - [`KernelError::BufferTooSmall`] if `output` has fewer than `n_rows` slots,
///   or `blocks` holds fewer than `n_rows * (k / QK1_0_G128)` blocks.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX2 and FMA.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
pub unsafe fn gemv_1bit_g128_avx2_prefetch(
    blocks: &[BlockQ1_0G128],
    input: &[f32],
    output: &mut [f32],
    n_rows: usize,
    k: usize,
) -> KernelResult<()> {
    if k % QK1_0_G128 != 0 {
        return Err(KernelError::NotBlockAligned {
            count: k,
            block_size: QK1_0_G128,
        });
    }
    if input.len() < k {
        return Err(KernelError::DimensionMismatch {
            expected: k,
            got: input.len(),
        });
    }
    if output.len() < n_rows {
        return Err(KernelError::BufferTooSmall {
            needed: n_rows,
            available: output.len(),
        });
    }
    let blocks_per_row = k / QK1_0_G128;
    let needed_blocks = n_rows * blocks_per_row;
    if blocks.len() < needed_blocks {
        return Err(KernelError::BufferTooSmall {
            needed: needed_blocks,
            available: blocks.len(),
        });
    }

    // All loads below stay inside input[..k].
    let x = input.as_ptr();
    for (row, slot) in output[..n_rows].iter_mut().enumerate() {
        let first = row * blocks_per_row;
        let row_blocks = &blocks[first..first + blocks_per_row];
        if row + 1 < n_rows {
            // Start pulling in the next row's first block while this row computes.
            _mm_prefetch(
                blocks.as_ptr().add(first + blocks_per_row).cast::<i8>(),
                _MM_HINT_T0,
            );
        }

        let mut even_acc = _mm256_setzero_ps();
        let mut odd_acc = _mm256_setzero_ps();
        let mut pairs = row_blocks.chunks_exact(2);
        for (pair_idx, pair) in pairs.by_ref().enumerate() {
            let following = 2 * pair_idx + 2;
            if following < blocks_per_row {
                _mm_prefetch(row_blocks.as_ptr().add(following).cast::<i8>(), _MM_HINT_T0);
            }
            let (lo, hi) = (&pair[0], &pair[1]);
            let lo_scale = _mm256_set1_ps(lo.d.to_f32());
            let hi_scale = _mm256_set1_ps(hi.d.to_f32());
            let x_lo = x.add(2 * pair_idx * QK1_0_G128);
            let x_hi = x_lo.add(QK1_0_G128);
            for group in 0..16 {
                let lo_in = _mm256_loadu_ps(x_lo.add(group * 8));
                let lo_signed = _mm256_mul_ps(bits_to_signs_avx2(lo.qs[group]), lo_in);
                even_acc = _mm256_fmadd_ps(lo_scale, lo_signed, even_acc);

                let hi_in = _mm256_loadu_ps(x_hi.add(group * 8));
                let hi_signed = _mm256_mul_ps(bits_to_signs_avx2(hi.qs[group]), hi_in);
                odd_acc = _mm256_fmadd_ps(hi_scale, hi_signed, odd_acc);
            }
        }

        // An odd block count leaves exactly one unpaired block at the end of the row.
        if let [tail] = pairs.remainder() {
            let tail_scale = _mm256_set1_ps(tail.d.to_f32());
            let x_tail = x.add((blocks_per_row - 1) * QK1_0_G128);
            for group in 0..16 {
                let tail_in = _mm256_loadu_ps(x_tail.add(group * 8));
                let tail_signed = _mm256_mul_ps(bits_to_signs_avx2(tail.qs[group]), tail_in);
                even_acc = _mm256_fmadd_ps(tail_scale, tail_signed, even_acc);
            }
        }

        *slot = hsum_avx2(_mm256_add_ps(even_acc, odd_acc));
    }
    Ok(())
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
pub unsafe fn gemm_1bit_g128_avx2_prefetch(
blocks: &[BlockQ1_0G128],
input: &[f32],
output: &mut [f32],
m: usize,
n_rows: usize,
k: usize,
) -> KernelResult<()> {
if k % QK1_0_G128 != 0 {
return Err(KernelError::NotBlockAligned {
count: k,
block_size: QK1_0_G128,
});
}
if input.len() < m * k {
return Err(KernelError::DimensionMismatch {
expected: m * k,
got: input.len(),
});
}
if output.len() < m * n_rows {
return Err(KernelError::BufferTooSmall {
needed: m * n_rows,
available: output.len(),
});
}
let blocks_per_row = k / QK1_0_G128;
let expected_blocks = n_rows * blocks_per_row;
if blocks.len() < expected_blocks {
return Err(KernelError::BufferTooSmall {
needed: expected_blocks,
available: blocks.len(),
});
}
for mi in 0..m {
let input_row = &input[mi * k..];
for ni in 0..n_rows {
let row_blocks = &blocks[ni * blocks_per_row..(ni + 1) * blocks_per_row];
if ni + 1 < n_rows {
let next_ptr = blocks.as_ptr().add((ni + 1) * blocks_per_row) as *const i8;
_mm_prefetch(next_ptr, _MM_HINT_T0);
}
let mut acc0 = _mm256_setzero_ps();
let mut acc1 = _mm256_setzero_ps();
let pairs = blocks_per_row / 2;
let remainder = blocks_per_row % 2;
for pair_idx in 0..pairs {
let bi0 = pair_idx * 2;
let bi1 = bi0 + 1;
let block0 = &row_blocks[bi0];
let block1 = &row_blocks[bi1];
let d0 = block0.d.to_f32();
let scale0 = _mm256_set1_ps(d0);
let base0 = bi0 * QK1_0_G128;
let d1 = block1.d.to_f32();
let scale1 = _mm256_set1_ps(d1);
let base1 = bi1 * QK1_0_G128;
for chunk in 0..16 {
let bits0 = block0.qs[chunk];
let inp0 = _mm256_loadu_ps(input_row.as_ptr().add(base0 + chunk * 8));
let signs0 = bits_to_signs_avx2(bits0);
acc0 = _mm256_fmadd_ps(scale0, _mm256_mul_ps(signs0, inp0), acc0);
let bits1 = block1.qs[chunk];
let inp1 = _mm256_loadu_ps(input_row.as_ptr().add(base1 + chunk * 8));
let signs1 = bits_to_signs_avx2(bits1);
acc1 = _mm256_fmadd_ps(scale1, _mm256_mul_ps(signs1, inp1), acc1);
}
}
for (bi, block) in row_blocks
.iter()
.enumerate()
.skip(pairs * 2)
.take(remainder)
{
let d = block.d.to_f32();
let scale = _mm256_set1_ps(d);
let base = bi * QK1_0_G128;
for chunk in 0..16 {
let bits = block.qs[chunk];
let inp = _mm256_loadu_ps(input_row.as_ptr().add(base + chunk * 8));
let signs = bits_to_signs_avx2(bits);
acc0 = _mm256_fmadd_ps(scale, _mm256_mul_ps(signs, inp), acc0);
}
}
let combined = _mm256_add_ps(acc0, acc1);
output[mi * n_rows + ni] = hsum_avx2(combined);
}
}
Ok(())
}
#[cfg(test)]
#[cfg(target_arch = "x86_64")]
mod tests {
    use super::*;
    use half::f16;

    fn make_block(scale: f32, bits: [u8; 16]) -> BlockQ1_0G128 {
        BlockQ1_0G128 {
            d: f16::from_f32(scale),
            qs: bits,
        }
    }

    /// Builds `n_rows * blocks_per_row` blocks whose sign bytes vary per row,
    /// per block and per byte, so lane/offset mistakes show up in comparisons.
    fn make_matrix(n_rows: usize, blocks_per_row: usize) -> Vec<BlockQ1_0G128> {
        let mut blocks = Vec::with_capacity(n_rows * blocks_per_row);
        for row in 0..n_rows {
            for bi in 0..blocks_per_row {
                let mut bits = [0u8; 16];
                for (c, byte) in bits.iter_mut().enumerate() {
                    // Truncation to u8 is intentional: it only scrambles the pattern.
                    *byte = (row * 37 + bi * 13 + c * 29) as u8;
                }
                blocks.push(make_block(0.5 + row as f32 * 0.1, bits));
            }
        }
        blocks
    }

    /// Element-wise comparison with a tolerance (1% relative, at least 0.1
    /// absolute) that absorbs summation-order differences between the SIMD
    /// kernels and the scalar references.
    fn assert_close(reference: &[f32], actual: &[f32]) {
        assert_eq!(reference.len(), actual.len(), "length mismatch");
        for (i, (&r, &a)) in reference.iter().zip(actual).enumerate() {
            let tol = 1e-2 * r.abs().max(10.0);
            assert!((r - a).abs() <= tol, "idx {i}: ref={r}, avx2={a}");
        }
    }

    fn has_avx2() -> bool {
        is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma")
    }

    #[test]
    fn avx2_dequant_all_positive() {
        if !has_avx2() {
            return;
        }
        let block = make_block(2.0, [0xFF; 16]);
        let mut output = vec![0.0f32; 128];
        unsafe {
            dequant_1bit_g128_avx2(&[block], &mut output).expect("avx2 dequant should succeed");
        }
        for &v in &output {
            assert!((v - 2.0).abs() < 0.01, "expected 2.0, got {v}");
        }
    }

    #[test]
    fn avx2_dequant_all_negative() {
        if !has_avx2() {
            return;
        }
        let block = make_block(3.0, [0x00; 16]);
        let mut output = vec![0.0f32; 128];
        unsafe {
            dequant_1bit_g128_avx2(&[block], &mut output).expect("avx2 dequant should succeed");
        }
        for &v in &output {
            assert!((v + 3.0).abs() < 0.01, "expected -3.0, got {v}");
        }
    }

    #[test]
    fn avx2_dequant_bit_order() {
        if !has_avx2() {
            return;
        }
        // Only bit 0 of every byte is set, so only the first value of each
        // group of eight should come out positive.
        let block = make_block(0.5, [0x01; 16]);
        let mut output = vec![0.0f32; 128];
        unsafe {
            dequant_1bit_g128_avx2(&[block], &mut output).expect("avx2 dequant should succeed");
        }
        for (i, &v) in output.iter().enumerate() {
            let expected = if i % 8 == 0 { 0.5 } else { -0.5 };
            assert_eq!(v, expected, "idx {i}");
        }
    }

    #[test]
    fn avx2_dequant_matches_reference() {
        if !has_avx2() {
            return;
        }
        let block = make_block(1.5, [0xAA; 16]);
        let mut out_ref = vec![0.0f32; 128];
        let mut out_avx = vec![0.0f32; 128];
        crate::dequant::dequant_1bit_g128(&[block], &mut out_ref)
            .expect("reference dequant should succeed");
        unsafe {
            dequant_1bit_g128_avx2(&[block], &mut out_avx).expect("avx2 dequant should succeed");
        }
        for i in 0..128 {
            assert!(
                (out_ref[i] - out_avx[i]).abs() < 0.01,
                "mismatch at {i}: ref={}, avx2={}",
                out_ref[i],
                out_avx[i]
            );
        }
    }

    #[test]
    fn avx2_dequant_rejects_short_output() {
        if !has_avx2() {
            return;
        }
        let block = make_block(1.0, [0x0F; 16]);
        let mut output = vec![0.0f32; 127];
        let result = unsafe { dequant_1bit_g128_avx2(&[block], &mut output) };
        assert!(matches!(
            result,
            Err(KernelError::BufferTooSmall {
                needed: 128,
                available: 127
            })
        ));
    }

    #[test]
    fn avx2_gemv_identity_like() {
        if !has_avx2() {
            return;
        }
        let blocks = vec![make_block(1.0, [0xFF; 16])];
        let input: Vec<f32> = (0..128).map(|i| i as f32).collect();
        let mut output = vec![0.0f32; 1];
        unsafe {
            gemv_1bit_g128_avx2(&blocks, &input, &mut output, 1, 128)
                .expect("avx2 gemv should succeed");
        }
        let expected: f32 = (0..128).map(|i| i as f32).sum();
        assert!(
            (output[0] - expected).abs() < 1.0,
            "expected ~{expected}, got {}",
            output[0]
        );
    }

    #[test]
    fn avx2_gemv_matches_reference() {
        if !has_avx2() {
            return;
        }
        let n_rows = 4;
        let k = 256;
        let blocks_per_row = k / QK1_0_G128;
        let mut blocks = Vec::new();
        for row in 0..n_rows {
            for bi in 0..blocks_per_row {
                let bits = [row as u8 * 37 + bi as u8 * 13; 16];
                blocks.push(make_block(0.5 + row as f32 * 0.1, bits));
            }
        }
        let input: Vec<f32> = (0..k).map(|i| (i as f32 * 0.01) - 1.28).collect();
        let mut out_ref = vec![0.0f32; n_rows];
        let mut out_avx = vec![0.0f32; n_rows];
        crate::gemv::gemv_1bit_g128(&blocks, &input, &mut out_ref, n_rows, k)
            .expect("reference gemv should succeed");
        unsafe {
            gemv_1bit_g128_avx2(&blocks, &input, &mut out_avx, n_rows, k)
                .expect("avx2 gemv should succeed");
        }
        for i in 0..n_rows {
            assert!(
                (out_ref[i] - out_avx[i]).abs() < 0.1,
                "row {i}: ref={}, avx2={}",
                out_ref[i],
                out_avx[i]
            );
        }
    }

    #[test]
    fn avx2_gemv_rejects_unaligned_k() {
        if !has_avx2() {
            return;
        }
        let input = vec![0.0f32; 100];
        let mut output = vec![0.0f32; 1];
        let result = unsafe { gemv_1bit_g128_avx2(&[], &input, &mut output, 1, 100) };
        assert!(matches!(
            result,
            Err(KernelError::NotBlockAligned { count: 100, .. })
        ));
    }

    #[test]
    fn avx2_gemv_prefetch_matches_reference() {
        if !has_avx2() {
            return;
        }
        // Three blocks per row exercises both the paired loop and the odd tail block.
        let n_rows = 3;
        let k = 3 * QK1_0_G128;
        let blocks = make_matrix(n_rows, k / QK1_0_G128);
        let input: Vec<f32> = (0..k).map(|i| (i as f32 * 0.01) - 1.28).collect();
        let mut out_ref = vec![0.0f32; n_rows];
        let mut out_avx = vec![0.0f32; n_rows];
        crate::gemv::gemv_1bit_g128(&blocks, &input, &mut out_ref, n_rows, k)
            .expect("reference gemv should succeed");
        unsafe {
            gemv_1bit_g128_avx2_prefetch(&blocks, &input, &mut out_avx, n_rows, k)
                .expect("avx2 prefetch gemv should succeed");
        }
        assert_close(&out_ref, &out_avx);
    }

    #[test]
    fn avx2_gemm_matches_reference() {
        if !has_avx2() {
            return;
        }
        let m = 2;
        let n_rows = 3;
        let k = 128;
        let blocks_per_row = k / QK1_0_G128;
        let mut blocks = Vec::new();
        for ni in 0..n_rows {
            for bi in 0..blocks_per_row {
                let bits = [(ni as u8 * 17 + bi as u8 * 7) | 0x55; 16];
                blocks.push(make_block(1.0 + ni as f32 * 0.2, bits));
            }
        }
        let input: Vec<f32> = (0..m * k).map(|i| (i as f32 * 0.005) - 0.32).collect();
        let mut out_ref = vec![0.0f32; m * n_rows];
        let mut out_avx = vec![0.0f32; m * n_rows];
        crate::gemm::gemm_1bit_g128(&blocks, &input, &mut out_ref, m, n_rows, k)
            .expect("reference gemm should succeed");
        unsafe {
            gemm_1bit_g128_avx2(&blocks, &input, &mut out_avx, m, n_rows, k)
                .expect("avx2 gemm should succeed");
        }
        for i in 0..(m * n_rows) {
            assert!(
                (out_ref[i] - out_avx[i]).abs() < 0.5,
                "idx {i}: ref={}, avx2={}",
                out_ref[i],
                out_avx[i]
            );
        }
    }

    #[test]
    fn avx2_gemm_rejects_short_input() {
        if !has_avx2() {
            return;
        }
        let k = QK1_0_G128;
        let blocks = make_matrix(1, 1);
        // Two input rows are requested but only one is supplied.
        let input = vec![0.0f32; k];
        let mut output = vec![0.0f32; 2];
        let result = unsafe { gemm_1bit_g128_avx2(&blocks, &input, &mut output, 2, 1, k) };
        assert!(matches!(
            result,
            Err(KernelError::DimensionMismatch { expected, got }) if expected == 2 * k && got == k
        ));
    }

    #[test]
    fn avx2_gemm_prefetch_matches_reference() {
        if !has_avx2() {
            return;
        }
        // Three blocks per row exercises both the paired loop and the odd tail block.
        let m = 2;
        let n_rows = 3;
        let k = 3 * QK1_0_G128;
        let blocks = make_matrix(n_rows, k / QK1_0_G128);
        let input: Vec<f32> = (0..m * k).map(|i| (i as f32 * 0.005) - 0.32).collect();
        let mut out_ref = vec![0.0f32; m * n_rows];
        let mut out_avx = vec![0.0f32; m * n_rows];
        crate::gemm::gemm_1bit_g128(&blocks, &input, &mut out_ref, m, n_rows, k)
            .expect("reference gemm should succeed");
        unsafe {
            gemm_1bit_g128_avx2_prefetch(&blocks, &input, &mut out_avx, m, n_rows, k)
                .expect("avx2 prefetch gemm should succeed");
        }
        assert_close(&out_ref, &out_avx);
    }
}