#[cfg(target_arch = "aarch64")]
use core::arch::aarch64::*;
#[cfg(target_arch = "aarch64")]
use oxibonsai_core::tensor::{BlockQ1_0G128, QK1_0_G128};
#[cfg(target_arch = "aarch64")]
use crate::error::{KernelError, KernelResult};
/// Expands 1-bit, group-of-128 quantized blocks into `f32` values with NEON.
///
/// Every weight is stored as a single bit that selects the sign of its
/// block's scale `d`: a set bit produces `+d`, a clear bit produces `-d`.
/// Bits are consumed least-significant first, so bit `b` of `qs[byte]` lands
/// at offset `byte * 8 + b` inside the block's slice of `output`.
///
/// Only the first `blocks.len() * QK1_0_G128` elements of `output` are
/// written; any extra tail is left untouched.
///
/// # Errors
///
/// Returns [`KernelError::BufferTooSmall`] when `output` holds fewer than
/// `blocks.len() * QK1_0_G128` elements.
///
/// # Safety
///
/// The caller must ensure the executing CPU supports the `neon` target
/// feature.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn dequant_1bit_g128_neon(
    blocks: &[BlockQ1_0G128],
    output: &mut [f32],
) -> KernelResult<()> {
    let needed = blocks.len() * QK1_0_G128;
    let available = output.len();
    if available < needed {
        return Err(KernelError::BufferTooSmall { needed, available });
    }

    // The length check above guarantees every store below stays inside
    // `output`. Each block writes 16 bytes x 8 bits = 128 values, which this
    // kernel assumes is exactly `QK1_0_G128`.
    let out_ptr = output.as_mut_ptr();
    for (block_idx, block) in blocks.iter().enumerate() {
        let scale = vdupq_n_f32(block.d.to_f32());
        let block_ptr = out_ptr.add(block_idx * QK1_0_G128);

        for (byte_idx, &bits) in block.qs.iter().enumerate() {
            let dst = block_ptr.add(byte_idx * 8);

            // Low nibble -> elements 0..4, high nibble -> elements 4..8.
            vst1q_f32(dst, vmulq_f32(scale, bits_to_signs_neon(bits, 0)));
            vst1q_f32(dst.add(4), vmulq_f32(scale, bits_to_signs_neon(bits, 4)));
        }
    }

    Ok(())
}
/// Matrix-vector product of a 1-bit quantized weight matrix with an `f32`
/// vector, using NEON.
///
/// The weights are laid out row-major as `n_rows` rows of `k / QK1_0_G128`
/// blocks. For each row, `output[row]` receives the dot product of the
/// dequantized row (`+d` for a set bit, `-d` for a clear bit, `d` being the
/// block scale) with the first `k` elements of `input`. Only the first
/// `n_rows` elements of `output` are written.
///
/// # Errors
///
/// * [`KernelError::NotBlockAligned`] if `k` is not a multiple of
///   `QK1_0_G128`.
/// * [`KernelError::DimensionMismatch`] if `input` has fewer than `k`
///   elements.
/// * [`KernelError::BufferTooSmall`] if `output` has fewer than `n_rows`
///   elements, or `blocks` has fewer than `n_rows * (k / QK1_0_G128)` blocks.
///
/// # Safety
///
/// The caller must ensure the executing CPU supports the `neon` target
/// feature.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn gemv_1bit_g128_neon(
    blocks: &[BlockQ1_0G128],
    input: &[f32],
    output: &mut [f32],
    n_rows: usize,
    k: usize,
) -> KernelResult<()> {
    if k % QK1_0_G128 != 0 {
        return Err(KernelError::NotBlockAligned {
            count: k,
            block_size: QK1_0_G128,
        });
    }

    let got = input.len();
    if got < k {
        return Err(KernelError::DimensionMismatch { expected: k, got });
    }

    let out_available = output.len();
    if out_available < n_rows {
        return Err(KernelError::BufferTooSmall {
            needed: n_rows,
            available: out_available,
        });
    }

    let blocks_per_row = k / QK1_0_G128;
    let blocks_needed = n_rows * blocks_per_row;
    if blocks.len() < blocks_needed {
        return Err(KernelError::BufferTooSmall {
            needed: blocks_needed,
            available: blocks.len(),
        });
    }

    // `input.len() >= k` was verified above and each row consumes exactly
    // `blocks_per_row * 128 == k` elements, so the raw loads stay in bounds.
    let x = input.as_ptr();
    for (row, out) in output[..n_rows].iter_mut().enumerate() {
        let weights = &blocks[row * blocks_per_row..(row + 1) * blocks_per_row];
        let mut acc = vdupq_n_f32(0.0);

        for (bi, block) in weights.iter().enumerate() {
            let scale = vdupq_n_f32(block.d.to_f32());
            let x_block = x.add(bi * QK1_0_G128);

            for (byte_idx, &bits) in block.qs.iter().enumerate() {
                let x8 = x_block.add(byte_idx * 8);

                // Low nibble pairs with elements 0..4 of this 8-element group.
                let lo = vmulq_f32(bits_to_signs_neon(bits, 0), vld1q_f32(x8));
                acc = vfmaq_f32(acc, scale, lo);

                // High nibble pairs with elements 4..8.
                let hi = vmulq_f32(bits_to_signs_neon(bits, 4), vld1q_f32(x8.add(4)));
                acc = vfmaq_f32(acc, scale, hi);
            }
        }

        *out = hsum_neon(acc);
    }

    Ok(())
}
/// Multiplies `m` dense `f32` input rows by a 1-bit quantized weight matrix
/// using NEON.
///
/// Layouts (all row-major):
/// * `input`  — `m` rows of `k` elements,
/// * `blocks` — `n_rows` weight rows of `k / QK1_0_G128` blocks,
/// * `output` — `m` rows of `n_rows` elements.
///
/// `output[mi * n_rows + ni]` receives the dot product of input row `mi` with
/// dequantized weight row `ni`, where a set bit contributes `+d` and a clear
/// bit `-d` (`d` being the block scale).
///
/// # Errors
///
/// * [`KernelError::NotBlockAligned`] if `k` is not a multiple of
///   `QK1_0_G128`.
/// * [`KernelError::DimensionMismatch`] if `input` has fewer than `m * k`
///   elements.
/// * [`KernelError::BufferTooSmall`] if `output` has fewer than `m * n_rows`
///   elements, or `blocks` has fewer than `n_rows * (k / QK1_0_G128)` blocks.
///
/// The size products saturate at `usize::MAX`, so dimensions whose product
/// does not fit in `usize` are rejected with the matching error (reporting
/// `usize::MAX` as the required size) instead of wrapping around.
///
/// # Safety
///
/// The caller must ensure the executing CPU supports the `neon` target
/// feature.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn gemm_1bit_g128_neon(
    blocks: &[BlockQ1_0G128],
    input: &[f32],
    output: &mut [f32],
    m: usize,
    n_rows: usize,
    k: usize,
) -> KernelResult<()> {
    if k % QK1_0_G128 != 0 {
        return Err(KernelError::NotBlockAligned {
            count: k,
            block_size: QK1_0_G128,
        });
    }

    // Saturating products: a wrapped size would slip past the length checks
    // and let the raw NEON loads below read outside `input`. `usize::MAX` is
    // larger than any slice of these element types can be, so an overflowing
    // request is always rejected here.
    let input_needed = m.saturating_mul(k);
    if input.len() < input_needed {
        return Err(KernelError::DimensionMismatch {
            expected: input_needed,
            got: input.len(),
        });
    }

    let output_needed = m.saturating_mul(n_rows);
    if output.len() < output_needed {
        return Err(KernelError::BufferTooSmall {
            needed: output_needed,
            available: output.len(),
        });
    }

    let blocks_per_row = k / QK1_0_G128;
    let blocks_needed = n_rows.saturating_mul(blocks_per_row);
    if blocks.len() < blocks_needed {
        return Err(KernelError::BufferTooSmall {
            needed: blocks_needed,
            available: blocks.len(),
        });
    }

    for mi in 0..m {
        // Exactly `k` elements: every load below (at most `k` elements from
        // the row start) is covered by this bounds-checked slice.
        let input_row = &input[mi * k..(mi + 1) * k];
        let x = input_row.as_ptr();
        let out_row = &mut output[mi * n_rows..(mi + 1) * n_rows];

        for (ni, out) in out_row.iter_mut().enumerate() {
            let row_blocks = &blocks[ni * blocks_per_row..(ni + 1) * blocks_per_row];
            let mut acc = vdupq_n_f32(0.0);

            for (bi, block) in row_blocks.iter().enumerate() {
                let scale = vdupq_n_f32(block.d.to_f32());
                // Each block covers 16 bytes x 8 bits = 128 inputs, which
                // this kernel assumes is exactly `QK1_0_G128`.
                let x_block = x.add(bi * QK1_0_G128);

                for (byte_idx, &bits) in block.qs.iter().enumerate() {
                    let x8 = x_block.add(byte_idx * 8);

                    // Low nibble pairs with elements 0..4 of this group.
                    let lo = vmulq_f32(bits_to_signs_neon(bits, 0), vld1q_f32(x8));
                    acc = vfmaq_f32(acc, scale, lo);

                    // High nibble pairs with elements 4..8.
                    let hi = vmulq_f32(bits_to_signs_neon(bits, 4), vld1q_f32(x8.add(4)));
                    acc = vfmaq_f32(acc, scale, hi);
                }
            }

            *out = hsum_neon(acc);
        }
    }

    Ok(())
}
/// Expands four consecutive bits of `bits` into a vector of `+1.0` / `-1.0`.
///
/// Lane `i` of the result is `+1.0` when bit `lane_offset + i` of `bits` is
/// set and `-1.0` when it is clear.
///
/// `lane_offset` must leave four bits inside the byte, i.e. be at most 4; the
/// kernels in this file only pass 0 (low nibble) and 4 (high nibble).
///
/// # Safety
///
/// The caller must ensure the executing CPU supports the `neon` target
/// feature.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[inline]
unsafe fn bits_to_signs_neon(bits: u8, lane_offset: usize) -> float32x4_t {
    // One selector bit per lane: lane `i` tests bit `i` of the shifted byte.
    const LANE_BITS: [u32; 4] = [1 << 0, 1 << 1, 1 << 2, 1 << 3];

    debug_assert!(
        lane_offset <= 4,
        "lane_offset must leave four bits inside the byte"
    );

    // Broadcast the shifted byte and test all four bits with one vector
    // instruction, instead of extracting each bit in scalar code and
    // inserting it lane by lane.
    let selectors = vld1q_u32(LANE_BITS.as_ptr());
    let shifted = vdupq_n_u32(u32::from(bits >> lane_offset));

    // All-ones in every lane whose selected bit is set, all-zeros otherwise.
    let is_set = vtstq_u32(shifted, selectors);

    vbslq_f32(is_set, vdupq_n_f32(1.0), vdupq_n_f32(-1.0))
}
/// Returns the sum of the four lanes of `v`.
///
/// The lanes are combined pairwise, as `(v0 + v1) + (v2 + v3)`.
///
/// # Safety
///
/// The caller must ensure the executing CPU supports the `neon` target
/// feature.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[inline]
unsafe fn hsum_neon(v: float32x4_t) -> f32 {
    // [v0 + v1, v2 + v3, v0 + v1, v2 + v3]
    let pairwise = vpaddq_f32(v, v);
    // Add the two distinct partial sums held in the low half.
    vpadds_f32(vget_low_f32(pairwise))
}
/// Matrix-vector product of a 1-bit quantized weight matrix with an `f32`
/// vector, using NEON with prefetch hints and two accumulators.
///
/// The weights are laid out row-major as `n_rows` rows of `k / QK1_0_G128`
/// blocks. For each row, `output[row]` receives the dot product of the
/// dequantized row (`+d` for a set bit, `-d` for a clear bit, `d` being the
/// block scale) with the first `k` elements of `input`. Only the first
/// `n_rows` elements of `output` are written.
///
/// While a row is processed, read-prefetch hints are issued for the start of
/// the next weight row and for the next block of the current row. Low-nibble
/// and high-nibble products are accumulated in separate vectors and combined
/// once per row.
///
/// # Errors
///
/// * [`KernelError::NotBlockAligned`] if `k` is not a multiple of
///   `QK1_0_G128`.
/// * [`KernelError::DimensionMismatch`] if `input` has fewer than `k`
///   elements.
/// * [`KernelError::BufferTooSmall`] if `output` has fewer than `n_rows`
///   elements, or `blocks` has fewer than `n_rows * (k / QK1_0_G128)` blocks.
///
/// # Safety
///
/// The caller must ensure the executing CPU supports the `neon` target
/// feature.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn gemv_1bit_g128_neon_prefetch(
    blocks: &[BlockQ1_0G128],
    input: &[f32],
    output: &mut [f32],
    n_rows: usize,
    k: usize,
) -> KernelResult<()> {
    if k % QK1_0_G128 != 0 {
        return Err(KernelError::NotBlockAligned {
            count: k,
            block_size: QK1_0_G128,
        });
    }
    if input.len() < k {
        return Err(KernelError::DimensionMismatch {
            expected: k,
            got: input.len(),
        });
    }
    if output.len() < n_rows {
        return Err(KernelError::BufferTooSmall {
            needed: n_rows,
            available: output.len(),
        });
    }

    let blocks_per_row = k / QK1_0_G128;
    let expected_blocks = n_rows * blocks_per_row;
    if blocks.len() < expected_blocks {
        return Err(KernelError::BufferTooSmall {
            needed: expected_blocks,
            available: blocks.len(),
        });
    }

    // `input.len() >= k` was verified above and each row consumes exactly
    // `blocks_per_row * 128 == k` elements, so the raw loads stay in bounds.
    let x = input.as_ptr();
    for row in 0..n_rows {
        let row_blocks = &blocks[row * blocks_per_row..(row + 1) * blocks_per_row];

        if row + 1 < n_rows {
            // Hint the start of the next weight row into cache while this
            // row is being reduced.
            let next_row = blocks.as_ptr().add((row + 1) * blocks_per_row);
            _prefetch::<_PREFETCH_READ, _PREFETCH_LOCALITY3>(next_row.cast::<i8>());
        }

        // Two independent accumulators: `acc_lo` collects low-nibble lanes,
        // `acc_hi` high-nibble lanes.
        let mut acc_lo = vdupq_n_f32(0.0);
        let mut acc_hi = vdupq_n_f32(0.0);

        for (bi, block) in row_blocks.iter().enumerate() {
            let scale = vdupq_n_f32(block.d.to_f32());

            if bi + 1 < blocks_per_row {
                let next_block = row_blocks.as_ptr().add(bi + 1);
                _prefetch::<_PREFETCH_READ, _PREFETCH_LOCALITY3>(next_block.cast::<i8>());
            }

            let x_block = x.add(bi * QK1_0_G128);

            // Two sign bytes (16 inputs) per iteration. A block holds 16
            // bytes, an even count, so no single-byte remainder can occur.
            for (pair_idx, pair) in block.qs.chunks_exact(2).enumerate() {
                let x0 = x_block.add(pair_idx * 16);
                let x1 = x0.add(8);

                let p0_lo = vmulq_f32(bits_to_signs_neon(pair[0], 0), vld1q_f32(x0));
                acc_lo = vfmaq_f32(acc_lo, scale, p0_lo);
                let p0_hi = vmulq_f32(bits_to_signs_neon(pair[0], 4), vld1q_f32(x0.add(4)));
                acc_hi = vfmaq_f32(acc_hi, scale, p0_hi);

                let p1_lo = vmulq_f32(bits_to_signs_neon(pair[1], 0), vld1q_f32(x1));
                acc_lo = vfmaq_f32(acc_lo, scale, p1_lo);
                let p1_hi = vmulq_f32(bits_to_signs_neon(pair[1], 4), vld1q_f32(x1.add(4)));
                acc_hi = vfmaq_f32(acc_hi, scale, p1_hi);
            }
        }

        output[row] = hsum_neon(vaddq_f32(acc_lo, acc_hi));
    }

    Ok(())
}
/// Multiplies `m` dense `f32` input rows by a 1-bit quantized weight matrix
/// using NEON with prefetch hints and two accumulators.
///
/// Layouts (all row-major):
/// * `input`  — `m` rows of `k` elements,
/// * `blocks` — `n_rows` weight rows of `k / QK1_0_G128` blocks,
/// * `output` — `m` rows of `n_rows` elements.
///
/// `output[mi * n_rows + ni]` receives the dot product of input row `mi` with
/// dequantized weight row `ni`, where a set bit contributes `+d` and a clear
/// bit `-d` (`d` being the block scale).
///
/// While a weight row is processed, read-prefetch hints are issued for the
/// start of the next weight row and for the next block of the current row.
/// Low-nibble and high-nibble products are accumulated in separate vectors
/// and combined once per output element.
///
/// # Errors
///
/// * [`KernelError::NotBlockAligned`] if `k` is not a multiple of
///   `QK1_0_G128`.
/// * [`KernelError::DimensionMismatch`] if `input` has fewer than `m * k`
///   elements.
/// * [`KernelError::BufferTooSmall`] if `output` has fewer than `m * n_rows`
///   elements, or `blocks` has fewer than `n_rows * (k / QK1_0_G128)` blocks.
///
/// The size products saturate at `usize::MAX`, so dimensions whose product
/// does not fit in `usize` are rejected with the matching error (reporting
/// `usize::MAX` as the required size) instead of wrapping around.
///
/// # Safety
///
/// The caller must ensure the executing CPU supports the `neon` target
/// feature.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn gemm_1bit_g128_neon_prefetch(
    blocks: &[BlockQ1_0G128],
    input: &[f32],
    output: &mut [f32],
    m: usize,
    n_rows: usize,
    k: usize,
) -> KernelResult<()> {
    if k % QK1_0_G128 != 0 {
        return Err(KernelError::NotBlockAligned {
            count: k,
            block_size: QK1_0_G128,
        });
    }

    // Saturating products: a wrapped size would slip past the length checks
    // and let the raw NEON loads below read outside `input`. `usize::MAX` is
    // larger than any slice of these element types can be, so an overflowing
    // request is always rejected here.
    let input_needed = m.saturating_mul(k);
    if input.len() < input_needed {
        return Err(KernelError::DimensionMismatch {
            expected: input_needed,
            got: input.len(),
        });
    }

    let output_needed = m.saturating_mul(n_rows);
    if output.len() < output_needed {
        return Err(KernelError::BufferTooSmall {
            needed: output_needed,
            available: output.len(),
        });
    }

    let blocks_per_row = k / QK1_0_G128;
    let blocks_needed = n_rows.saturating_mul(blocks_per_row);
    if blocks.len() < blocks_needed {
        return Err(KernelError::BufferTooSmall {
            needed: blocks_needed,
            available: blocks.len(),
        });
    }

    for mi in 0..m {
        // Exactly `k` elements: every load below (at most `k` elements from
        // the row start) is covered by this bounds-checked slice.
        let input_row = &input[mi * k..(mi + 1) * k];
        let x = input_row.as_ptr();

        for ni in 0..n_rows {
            let row_blocks = &blocks[ni * blocks_per_row..(ni + 1) * blocks_per_row];

            if ni + 1 < n_rows {
                // Hint the start of the next weight row into cache while
                // this one is being reduced.
                let next_row = blocks.as_ptr().add((ni + 1) * blocks_per_row);
                _prefetch::<_PREFETCH_READ, _PREFETCH_LOCALITY3>(next_row.cast::<i8>());
            }

            // Two independent accumulators: `acc_lo` collects low-nibble
            // lanes, `acc_hi` high-nibble lanes.
            let mut acc_lo = vdupq_n_f32(0.0);
            let mut acc_hi = vdupq_n_f32(0.0);

            for (bi, block) in row_blocks.iter().enumerate() {
                let scale = vdupq_n_f32(block.d.to_f32());

                if bi + 1 < blocks_per_row {
                    let next_block = row_blocks.as_ptr().add(bi + 1);
                    _prefetch::<_PREFETCH_READ, _PREFETCH_LOCALITY3>(next_block.cast::<i8>());
                }

                let x_block = x.add(bi * QK1_0_G128);

                // Two sign bytes (16 inputs) per iteration. A block holds 16
                // bytes, an even count, so no single-byte remainder can
                // occur.
                for (pair_idx, pair) in block.qs.chunks_exact(2).enumerate() {
                    let x0 = x_block.add(pair_idx * 16);
                    let x1 = x0.add(8);

                    let p0_lo = vmulq_f32(bits_to_signs_neon(pair[0], 0), vld1q_f32(x0));
                    acc_lo = vfmaq_f32(acc_lo, scale, p0_lo);
                    let p0_hi = vmulq_f32(bits_to_signs_neon(pair[0], 4), vld1q_f32(x0.add(4)));
                    acc_hi = vfmaq_f32(acc_hi, scale, p0_hi);

                    let p1_lo = vmulq_f32(bits_to_signs_neon(pair[1], 0), vld1q_f32(x1));
                    acc_lo = vfmaq_f32(acc_lo, scale, p1_lo);
                    let p1_hi = vmulq_f32(bits_to_signs_neon(pair[1], 4), vld1q_f32(x1.add(4)));
                    acc_hi = vfmaq_f32(acc_hi, scale, p1_hi);
                }
            }

            output[mi * n_rows + ni] = hsum_neon(vaddq_f32(acc_lo, acc_hi));
        }
    }

    Ok(())
}
#[cfg(test)]
#[cfg(target_arch = "aarch64")]
mod tests {
    use super::*;
    use half::f16;

    /// Builds one block from an `f32` scale (stored as `f16`) and 16 packed
    /// sign bytes.
    fn make_block(scale: f32, bits: [u8; 16]) -> BlockQ1_0G128 {
        let d = f16::from_f32(scale);
        BlockQ1_0G128 { d, qs: bits }
    }

    /// Weights and input shared by the small GEMV comparisons.
    ///
    /// Returns `(blocks, input, n_rows, k)` for 4 rows and `k = 256`.
    fn small_gemv_case() -> (Vec<BlockQ1_0G128>, Vec<f32>, usize, usize) {
        let n_rows = 4;
        let k = 256;
        let blocks_per_row = k / QK1_0_G128;
        let blocks = (0..n_rows)
            .flat_map(|row| {
                (0..blocks_per_row).map(move |bi| {
                    make_block(
                        0.5 + row as f32 * 0.1,
                        [row as u8 * 37 + bi as u8 * 13; 16],
                    )
                })
            })
            .collect();
        let input = (0..k).map(|i| (i as f32 * 0.01) - 1.28).collect();
        (blocks, input, n_rows, k)
    }

    /// Weights and input shared by the GEMM comparisons.
    ///
    /// Returns `(blocks, input, m, n_rows, k)` for 2 input rows, 3 weight
    /// rows and `k = 128`.
    fn small_gemm_case() -> (Vec<BlockQ1_0G128>, Vec<f32>, usize, usize, usize) {
        let m = 2;
        let n_rows = 3;
        let k = 128;
        let blocks_per_row = k / QK1_0_G128;
        let blocks = (0..n_rows)
            .flat_map(|ni| {
                (0..blocks_per_row).map(move |bi| {
                    make_block(
                        1.0 + ni as f32 * 0.2,
                        [(ni as u8 * 17 + bi as u8 * 7) | 0x55; 16],
                    )
                })
            })
            .collect();
        let input = (0..m * k).map(|i| (i as f32 * 0.005) - 0.32).collect();
        (blocks, input, m, n_rows, k)
    }

    /// All bits set: every output equals `+scale`.
    #[test]
    fn test_dequant_neon_all_positive() {
        let block = make_block(2.0, [0xFF; 16]);
        let mut output = vec![0.0f32; 128];
        unsafe {
            dequant_1bit_g128_neon(&[block], &mut output).expect("dequant should succeed");
        }
        for v in output.iter().copied() {
            assert!((v - 2.0).abs() < 0.01, "expected 2.0, got {v}");
        }
    }

    /// All bits clear: every output equals `-scale`.
    #[test]
    fn test_dequant_neon_all_negative() {
        let block = make_block(3.0, [0x00; 16]);
        let mut output = vec![0.0f32; 128];
        unsafe {
            dequant_1bit_g128_neon(&[block], &mut output).expect("dequant should succeed");
        }
        for v in output.iter().copied() {
            assert!((v + 3.0).abs() < 0.01, "expected -3.0, got {v}");
        }
    }

    /// An alternating bit pattern must match the reference dequantizer.
    #[test]
    fn test_dequant_neon_matches_reference() {
        let block = make_block(1.5, [0xAA; 16]);
        let mut out_ref = vec![0.0f32; 128];
        let mut out_neon = vec![0.0f32; 128];
        crate::dequant::dequant_1bit_g128(&[block], &mut out_ref)
            .expect("reference dequant should succeed");
        unsafe {
            dequant_1bit_g128_neon(&[block], &mut out_neon).expect("neon dequant should succeed");
        }
        for (i, (expected, actual)) in out_ref.iter().zip(&out_neon).enumerate() {
            assert!(
                (expected - actual).abs() < 0.01,
                "mismatch at {i}: ref={}, neon={}",
                expected,
                actual
            );
        }
    }

    /// The plain NEON GEMV must agree with the reference GEMV.
    #[test]
    fn test_gemv_neon_matches_reference() {
        let (blocks, input, n_rows, k) = small_gemv_case();
        let mut out_ref = vec![0.0f32; n_rows];
        let mut out_neon = vec![0.0f32; n_rows];
        crate::gemv::gemv_1bit_g128(&blocks, &input, &mut out_ref, n_rows, k)
            .expect("reference gemv should succeed");
        unsafe {
            gemv_1bit_g128_neon(&blocks, &input, &mut out_neon, n_rows, k)
                .expect("neon gemv should succeed");
        }
        for (i, (expected, actual)) in out_ref.iter().zip(&out_neon).enumerate() {
            assert!(
                (expected - actual).abs() < 0.1,
                "row {i}: ref={}, neon={}",
                expected,
                actual
            );
        }
    }

    /// The plain NEON GEMM must agree with the reference GEMM.
    #[test]
    fn test_gemm_neon_matches_reference() {
        let (blocks, input, m, n_rows, k) = small_gemm_case();
        let mut out_ref = vec![0.0f32; m * n_rows];
        let mut out_neon = vec![0.0f32; m * n_rows];
        crate::gemm::gemm_1bit_g128(&blocks, &input, &mut out_ref, m, n_rows, k)
            .expect("reference gemm should succeed");
        unsafe {
            gemm_1bit_g128_neon(&blocks, &input, &mut out_neon, m, n_rows, k)
                .expect("neon gemm should succeed");
        }
        for (i, (expected, actual)) in out_ref.iter().zip(&out_neon).enumerate() {
            assert!(
                (expected - actual).abs() < 0.5,
                "idx {i}: ref={}, neon={}",
                expected,
                actual
            );
        }
    }

    /// The prefetching GEMV must agree with the reference GEMV.
    #[test]
    fn test_gemv_neon_prefetch_matches_reference() {
        let (blocks, input, n_rows, k) = small_gemv_case();
        let mut out_ref = vec![0.0f32; n_rows];
        let mut out_pf = vec![0.0f32; n_rows];
        crate::gemv::gemv_1bit_g128(&blocks, &input, &mut out_ref, n_rows, k)
            .expect("reference gemv should succeed");
        unsafe {
            gemv_1bit_g128_neon_prefetch(&blocks, &input, &mut out_pf, n_rows, k)
                .expect("neon prefetch gemv should succeed");
        }
        for (i, (expected, actual)) in out_ref.iter().zip(&out_pf).enumerate() {
            assert!(
                (expected - actual).abs() < 0.1,
                "row {i}: ref={}, prefetch={}",
                expected,
                actual
            );
        }
    }

    /// Larger shape (64 rows, `k = 512`) for the prefetching GEMV.
    #[test]
    fn test_gemv_neon_prefetch_large() {
        let n_rows = 64;
        let k = 512;
        let blocks_per_row = k / QK1_0_G128;
        let blocks: Vec<BlockQ1_0G128> = (0..n_rows)
            .flat_map(|row| {
                (0..blocks_per_row).map(move |bi| {
                    make_block(
                        0.3 + row as f32 * 0.005,
                        [((row * 23 + bi * 11) & 0xFF) as u8; 16],
                    )
                })
            })
            .collect();
        let input: Vec<f32> = (0..k).map(|i| (i as f32 * 0.005) - 1.0).collect();
        let mut out_ref = vec![0.0f32; n_rows];
        let mut out_pf = vec![0.0f32; n_rows];
        crate::gemv::gemv_1bit_g128(&blocks, &input, &mut out_ref, n_rows, k)
            .expect("reference gemv should succeed");
        unsafe {
            gemv_1bit_g128_neon_prefetch(&blocks, &input, &mut out_pf, n_rows, k)
                .expect("neon prefetch gemv should succeed");
        }
        for (i, (expected, actual)) in out_ref.iter().zip(&out_pf).enumerate() {
            assert!(
                (expected - actual).abs() < 0.5,
                "row {i}: ref={}, prefetch={}",
                expected,
                actual
            );
        }
    }

    /// The prefetching GEMM must agree with the reference GEMM.
    #[test]
    fn test_gemm_neon_prefetch_matches_reference() {
        let (blocks, input, m, n_rows, k) = small_gemm_case();
        let mut out_ref = vec![0.0f32; m * n_rows];
        let mut out_pf = vec![0.0f32; m * n_rows];
        crate::gemm::gemm_1bit_g128(&blocks, &input, &mut out_ref, m, n_rows, k)
            .expect("reference gemm should succeed");
        unsafe {
            gemm_1bit_g128_neon_prefetch(&blocks, &input, &mut out_pf, m, n_rows, k)
                .expect("neon prefetch gemm should succeed");
        }
        for (i, (expected, actual)) in out_ref.iter().zip(&out_pf).enumerate() {
            assert!(
                (expected - actual).abs() < 0.5,
                "idx {i}: ref={}, prefetch={}",
                expected,
                actual
            );
        }
    }
}