use crate::error::BitTTTError;
use crate::kernels::packing::PackedTensor;
use candle_core::{Result, Tensor};
use rayon::prelude::*;
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
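
/// CPU implementation of the BitLinear forward pass: a matmul against
/// ternary weights {-1, 0, +1} stored as four 2-bit codes per byte, with an
/// AVX2+FMA fast path on x86_64 and a portable scalar fallback.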
#[derive(Debug, Clone)]
pub struct BitLinearCpu;

impl BitLinearCpu {
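    /// Computes `output = input · Wᵀ * scale` for an `[m, k]` f32 `input` and
    /// a packed `[n, k]` ternary weight matrix, returning an `[m, n]` tensor.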
    pub fn forward(input: &Tensor, weights: &PackedTensor) -> Result<Tensor> {
        let (m, k) = input.dims2()?;
        let (n, k_w) = weights.shape.dims2()?;
        if k != k_w {
            return Err(BitTTTError::shape_mismatch(format!(
                "Input [{}, {}] vs Weight [{}, {}]",
                m, k, n, k_w
            ))
            .into());
        }
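
        // Multi-base weights are not handled by the 2-bit kernel below:
        // dequantize them on the CPU and defer to candle's dense matmul.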
        if weights.is_multibase() {
            let w_dequant = weights.unpack(&candle_core::Device::Cpu)?;
            let w_t = w_dequant.t()?;
            return input.matmul(&w_t);
        }
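
        // Flatten the activations and borrow the packed weight bytes directly
        // from candle's CPU storage so the hot loop works on raw slices.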
        let x_vec = input.flatten_all()?.to_vec1::<f32>()?;
        let (w_storage, w_layout) = weights.data.storage_and_layout();
        let w_slice = match &*w_storage {
            candle_core::Storage::Cpu(storage) => storage.as_slice::<u8>()?,
            _ => {
                return Err(BitTTTError::storage_error(
                    "BitLinearCpu: Weights must be on CPU storage",
                )
                .into())
            }
        };
        if !w_layout.is_contiguous() {
            return Err(
                BitTTTError::storage_error("BitLinearCpu: Weights must be contiguous").into(),
            );
        }
        // `as_slice` covers the whole storage buffer, so honor the layout's
        // start offset in case the tensor is a view into shared storage.
        let w_slice = &w_slice[w_layout.start_offset()..];
        let output_len = m * n;
        let mut output = vec![0.0f32; output_len];
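
        // Decode table for the 2-bit weight codes: 0b00 -> 0.0, 0b01 -> +1.0,
        // 0b10 -> -1.0, and 0b11 likewise decodes to 0.0 (matching the AVX2
        // path's `(c & 1) - (c >> 1)` decode below).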
        const LUT: [f32; 4] = [0.0, 1.0, -1.0, 0.0];

        // Take the SIMD path only when the running CPU has both AVX2 and FMA.
        #[cfg(target_arch = "x86_64")]
        let has_avx2 = is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma");
        #[cfg(not(target_arch = "x86_64"))]
        let has_avx2 = false;
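
        // Every output element (i, j) is an independent dot product, so
        // parallelize over the flat [m * n] buffer with rayon.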
        output
            .par_iter_mut()
            .enumerate()
            .for_each(|(global_idx, out_val)| {
                let i = global_idx / n;
                let j = global_idx % n;
                let mut sum = 0.0f32;
                // Each packed weight row occupies ceil(k / 4) bytes.
                let w_row_start = j * k.div_ceil(4);
                let x_row_start = i * k;
                let mut processed = 0;
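
                // Fast path: consume k in chunks of 32 elements (8 packed
                // bytes); whatever remains falls to the scalar tail below.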
                if has_avx2 {
                    let chunk_size = 32;
                    let num_chunks = k / chunk_size;
                    // The cfg keeps non-x86_64 builds compiling; `has_avx2`
                    // is always false there, so this branch never runs.
                    #[cfg(target_arch = "x86_64")]
                    unsafe {
                        // Safety: feature detection above guarantees AVX2 and
                        // FMA, and both slices hold at least `num_chunks`
                        // whole chunks past the row start.
                        sum += compute_row_avx2(
                            &x_vec[x_row_start..],
                            &w_slice[w_row_start..],
                            num_chunks,
                        );
                    }
                    processed = num_chunks * chunk_size;
                }
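
                // Scalar tail: decode the remaining 2-bit codes via the LUT.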
                for l in processed..k {
                    let x_val = unsafe { *x_vec.get_unchecked(x_row_start + l) };
                    let byte_idx = l / 4;
                    let bit_idx = l % 4;
                    // Defensive guard; never hit when rows are packed to
                    // ceil(k / 4) bytes each.
                    if w_row_start + byte_idx >= w_slice.len() {
                        break;
                    }
                    let byte = unsafe { *w_slice.get_unchecked(w_row_start + byte_idx) };
                    let code = (byte >> (bit_idx * 2)) & 0b11;
                    let coeff = unsafe { *LUT.get_unchecked(code as usize) };
                    sum += x_val * coeff;
                }
                // A single per-tensor scale maps ternary sums back to f32.
                *out_val = sum * weights.scale;
            });
        Tensor::from_vec(output, (m, n), &candle_core::Device::Cpu)
    }
}
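
/// Horizontal sum of the eight f32 lanes of an AVX `__m256` register.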
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn hsum256_ps(v: __m256) -> f32 {
    // Fold 256 -> 128 -> 64 -> 32 bits, adding lanes pairwise at each step.
    let high = _mm256_extractf128_ps(v, 1);
    let low = _mm256_castps256_ps128(v);
    let sum128 = _mm_add_ps(high, low);
    let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
    let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
    _mm_cvtss_f32(sum32)
}
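
/// AVX2+FMA inner loop: accumulates `num_chunks * 32` input/weight products
/// and returns the partial dot product for one output element.
///
/// # Safety
/// The caller must verify AVX2 and FMA at runtime and pass slices holding at
/// least `num_chunks * 32` floats and `num_chunks * 8` packed bytes.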
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
unsafe fn compute_row_avx2(x_ptr: &[f32], w_ptr: &[u8], num_chunks: usize) -> f32 {
    let mut sum_vec = _mm256_setzero_ps();
    let mut x_curr = x_ptr.as_ptr();
    let mut w_curr = w_ptr.as_ptr();
    for _ in 0..num_chunks {
        // One chunk = 32 elements, processed as 4 groups of (2 weight bytes,
        // 8 activations).
        for _ in 0..4 {
            // The byte cursor carries no u16 alignment guarantee, so a plain
            // `*(w_curr as *const u16)` deref would be UB; read unaligned.
            let w_val = (w_curr as *const u16).read_unaligned();
            w_curr = w_curr.add(2);
            // Decode eight 2-bit codes to {-1.0, 0.0, +1.0}. For a code c,
            // (c & 1) - (c >> 1) agrees with the scalar LUT:
            // 00 -> 0, 01 -> +1, 10 -> -1, 11 -> 0.
            let mut coeffs = [0.0f32; 8];
            for (b, coeff) in coeffs.iter_mut().enumerate() {
                let shift = b * 2;
                let code = (w_val >> shift) & 0x03;
                let val = ((code & 1) as i32) - ((code >> 1) as i32);
                *coeff = val as f32;
            }
            let w_vec = _mm256_loadu_ps(coeffs.as_ptr());
            let x_vec = _mm256_loadu_ps(x_curr);
            x_curr = x_curr.add(8);
            // sum_vec += x * w across all eight lanes in one fused op.
            sum_vec = _mm256_fmadd_ps(x_vec, w_vec, sum_vec);
        }
    }
    hsum256_ps(sum_vec)
}