use super::ternary_tensor::TernaryTensor;
const BLOCK_SIZE: usize = 256;
/// Symmetric absmax quantization of f32 activations to i8.
///
/// Returns the quantized values plus the scale `127.0 / max(|x|)`.
/// Empty and effectively-zero inputs return a scale of 1.0 so callers
/// may divide by the returned scale unconditionally.
#[inline]
pub fn absmax_quantize_activations(input: &[f32]) -> (Vec<i8>, f32) {
    if input.is_empty() {
        return (vec![], 1.0);
    }
    let mut abs_max = 0.0f32;
    for &x in input {
        abs_max = abs_max.max(x.abs());
    }
    // Guard: near-zero magnitude would blow up the scale below.
    if abs_max < 1e-10 {
        return (vec![0i8; input.len()], 1.0);
    }
    let scale = 127.0 / abs_max;
    let quantized: Vec<i8> = input
        .iter()
        .map(|&x| (x * scale).round().clamp(-127.0, 127.0) as i8)
        .collect();
    (quantized, scale)
}
/// Builds a 256-entry i16 lookup table for a pair of ternary weights.
///
/// The table index is an activation byte reinterpreted as i8; each entry
/// holds `w0 * act + w1 * act`.
///
/// NOTE(review): both weights multiply the *same* decoded activation, so
/// the table collapses to `(w0 + w1) * act` (and `(1, -1)` is identically
/// zero, as the unit tests assert). Confirm this matches the intended TL1
/// pairing scheme before relying on it for paired-activation lookups.
#[inline]
pub fn generate_tl1_lut(weights_pair: (i8, i8)) -> [i16; 256] {
    let (w0, w1) = weights_pair;
    let mut lut = [0i16; 256];
    for (idx, slot) in lut.iter_mut().enumerate() {
        let act = idx as u8 as i8;
        // Max magnitude is 254 * 128 = 32512, which fits in i16.
        *slot = (w0 as i16 + w1 as i16) * (act as i16);
    }
    lut
}
/// Decodes one 2-bit ternary weight code to {-1, 0, +1}.
///
/// Only the low two bits are inspected: 0b00 -> -1, 0b10 -> +1, and both
/// 0b01 and the reserved code 0b11 map to 0.
#[inline(always)]
fn decode_ternary_2bit(bits: u8) -> i8 {
    let code = bits & 0x03;
    if code == 0b00 {
        -1
    } else if code == 0b10 {
        1
    } else {
        0
    }
}
/// Portable scalar TL1 GEMV kernel: `output = W * act` with 2-bit packed
/// ternary weights.
///
/// Layout: each row occupies `ceil(in_features / 4)` bytes of `packed`,
/// four 2-bit codes per byte starting at the least-significant bits.
/// `scales` holds one f32 per BLOCK_SIZE flat weight elements, and
/// `act_scale` is the absmax quantization scale, so each row dequantizes
/// as `acc * weight_scale / act_scale`.
#[inline]
fn tl1_gemv_scalar(
    packed: &[u8],
    scales: &[f32],
    act_i8: &[i8],
    act_scale: f32,
    out_features: usize,
    in_features: usize,
    output: &mut [f32],
) {
    // A vanishing activation scale carries no signal; zero the output
    // instead of dividing by ~0 below.
    if act_scale.abs() < 1e-30 {
        output.iter_mut().for_each(|v| *v = 0.0);
        return;
    }
    let packed_cols = (in_features + 3) / 4;
    for row in 0..out_features {
        let row_base = row * packed_cols;
        let mut acc: i32 = 0;
        for col in 0..in_features {
            // Stop early if the packed buffer is shorter than expected.
            let byte = match packed.get(row_base + col / 4) {
                Some(&b) => b,
                None => break,
            };
            let code = (byte >> ((col % 4) * 2)) & 0x03;
            acc += i32::from(decode_ternary_2bit(code)) * i32::from(act_i8[col]);
        }
        // The whole row reuses the scale of the block containing its first
        // flat element. NOTE(review): a row spanning several BLOCK_SIZE
        // blocks would need per-block scales — confirm tensors keep blocks
        // aligned with rows.
        let weight_scale = scales
            .get(row * in_features / BLOCK_SIZE)
            .copied()
            .unwrap_or(1.0);
        output[row] = (acc as f32) * weight_scale / act_scale;
    }
}
/// NEON-accelerated TL1 GEMV for aarch64; same contract as
/// `tl1_gemv_scalar`, processing 16 weights/activations per iteration.
///
/// Per row: 4 packed bytes expand into 16 decoded i8 weights; products are
/// widened i8 -> i16 -> i32 and accumulated in two i32x4 registers, then
/// the ragged tail (`in_features % 16`) is handled with scalar code.
///
/// # Safety
/// Uses `get_unchecked` on `packed` and `act_i8`: the caller must ensure
/// `packed.len() >= out_features * ceil(in_features / 4)` and
/// `act_i8.len() >= in_features`.
///
/// NOTE(review): unlike `tl1_gemv_scalar` there is no guard for
/// `act_scale ~ 0` before the division below; today the only caller
/// (`tl1_gemv`) obtains the scale from `absmax_quantize_activations`,
/// which never returns 0 — confirm before adding other call sites.
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
unsafe fn tl1_gemv_neon(
    packed: &[u8],
    scales: &[f32],
    act_i8: &[i8],
    act_scale: f32,
    out_features: usize,
    in_features: usize,
    output: &mut [f32],
) {
    use std::arch::aarch64::*;
    // Each packed byte holds four 2-bit ternary codes.
    let packed_cols = (in_features + 3) / 4;
    for row in 0..out_features {
        let row_packed_start = row * packed_cols;
        let mut acc0 = vdupq_n_s32(0);
        let mut acc1 = vdupq_n_s32(0);
        let chunks_16 = in_features / 16;
        let mut col = 0usize;
        for _ in 0..chunks_16 {
            // Four bytes supply the 16 ternary weights of this chunk.
            let packed_offset = row_packed_start + col / 4;
            let b0 = *packed.get_unchecked(packed_offset);
            let b1 = *packed.get_unchecked(packed_offset + 1);
            let b2 = *packed.get_unchecked(packed_offset + 2);
            let b3 = *packed.get_unchecked(packed_offset + 3);
            let mut w = [0i8; 16];
            let bytes = [b0, b1, b2, b3];
            for (bi, &byte_val) in bytes.iter().enumerate() {
                for vi in 0..4 {
                    let encoded = (byte_val >> (vi * 2)) & 0x03;
                    w[bi * 4 + vi] = decode_ternary_2bit(encoded);
                }
            }
            let w_vec = vld1q_s8(w.as_ptr());
            let a_vec = vld1q_s8(act_i8.as_ptr().add(col));
            // Widen both operands to i16 lanes so i8 x i8 cannot overflow.
            let w_lo = vmovl_s8(vget_low_s8(w_vec)); let w_hi = vmovl_s8(vget_high_s8(w_vec)); let a_lo = vmovl_s8(vget_low_s8(a_vec)); let a_hi = vmovl_s8(vget_high_s8(a_vec));
            let prod_lo = vmulq_s16(w_lo, a_lo); let prod_hi = vmulq_s16(w_hi, a_hi);
            // Widen the i16 products to i32 before accumulating across columns.
            let prod_lo_lo = vmovl_s16(vget_low_s16(prod_lo)); let prod_lo_hi = vmovl_s16(vget_high_s16(prod_lo)); let prod_hi_lo = vmovl_s16(vget_low_s16(prod_hi)); let prod_hi_hi = vmovl_s16(vget_high_s16(prod_hi));
            acc0 = vaddq_s32(acc0, prod_lo_lo);
            acc0 = vaddq_s32(acc0, prod_lo_hi);
            acc1 = vaddq_s32(acc1, prod_hi_lo);
            acc1 = vaddq_s32(acc1, prod_hi_hi);
            col += 16;
        }
        // Horizontal reduction of the two vector accumulators to one i32.
        let combined = vaddq_s32(acc0, acc1);
        let acc_i32 = vaddvq_s32(combined);
        let mut scalar_acc = acc_i32;
        // Scalar tail for the remaining in_features % 16 columns.
        for c in col..in_features {
            let byte_idx = row_packed_start + c / 4;
            let bit_offset = (c % 4) * 2;
            let encoded = (*packed.get_unchecked(byte_idx) >> bit_offset) & 0x03;
            let weight = decode_ternary_2bit(encoded);
            scalar_acc += (weight as i32) * (*act_i8.get_unchecked(c) as i32);
        }
        // One weight scale per BLOCK_SIZE flat elements; the whole row uses
        // the scale of its first element's block (matches the scalar path).
        let flat_offset = row * in_features;
        let block_idx = flat_offset / BLOCK_SIZE;
        let weight_scale = scales.get(block_idx).copied().unwrap_or(1.0);
        output[row] = (scalar_acc as f32) * weight_scale / act_scale;
    }
}
/// TL1 GEMV entry point: `output = weights * activations`.
///
/// Quantizes the activations to i8 via absmax, then dispatches to the NEON
/// kernel on aarch64 builds compiled with the `neon` target feature and to
/// the portable scalar kernel everywhere else.
///
/// # Panics
/// Panics when `activations.len() != in_features` or
/// `output.len() != out_features` of `weights.shape`.
pub fn tl1_gemv(weights: &TernaryTensor, activations: &[f32], output: &mut [f32]) {
    let (out_features, in_features) = weights.shape;
    assert_eq!(
        activations.len(),
        in_features,
        "Activation length {} does not match weight columns {}",
        activations.len(),
        in_features
    );
    assert_eq!(
        output.len(),
        out_features,
        "Output length {} does not match weight rows {}",
        output.len(),
        out_features
    );
    let (quantized, q_scale) = absmax_quantize_activations(activations);
    #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
    unsafe {
        // SAFETY: the asserts above guarantee activation/output lengths;
        // packed_data sizing is TernaryTensor's packing invariant.
        tl1_gemv_neon(
            &weights.packed_data,
            &weights.scales,
            &quantized,
            q_scale,
            out_features,
            in_features,
            output,
        );
    }
    #[cfg(not(all(target_arch = "aarch64", target_feature = "neon")))]
    tl1_gemv_scalar(
        &weights.packed_data,
        &weights.scales,
        &quantized,
        q_scale,
        out_features,
        in_features,
        output,
    );
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bitnet::{absmean_ternary, pack_ternary, TernaryTensor};

    /// Shared tolerance for near-exact float comparisons.
    const EPSILON: f32 = 1e-4;

    // ---- LUT generation -------------------------------------------------

    #[test]
    fn test_lut_generation_identity_weights() {
        let lut = generate_tl1_lut((1, 1));
        assert_eq!(lut[1], 2, "(1,1) with act=1 should give 2");
        assert_eq!(lut[127], 254, "(1,1) with act=127 should give 254");
        assert_eq!(lut[255], -2, "(1,1) with act=-1 should give -2");
    }

    #[test]
    fn test_lut_generation_opposite_weights() {
        let lut = generate_tl1_lut((1, -1));
        for &entry in lut.iter() {
            assert_eq!(entry, 0, "(1,-1) should always give 0");
        }
    }

    #[test]
    fn test_lut_generation_zero_weights() {
        let lut = generate_tl1_lut((0, 0));
        for &entry in lut.iter() {
            assert_eq!(entry, 0, "(0,0) should always give 0");
        }
    }

    #[test]
    fn test_lut_generation_single_weight() {
        // With (w0, w1) = (1, 0) the table reproduces the index byte as i8.
        let lut = generate_tl1_lut((1, 0));
        assert_eq!(lut[1], 1);
        assert_eq!(lut[127], 127);
        assert_eq!(lut[255], -1);
        assert_eq!(lut[128], -128);
    }

    #[test]
    fn test_lut_generation_negative_weight() {
        let lut = generate_tl1_lut((-1, 0));
        assert_eq!(lut[1], -1);
        assert_eq!(lut[127], -127);
        assert_eq!(lut[255], 1);
    }

    // ---- activation quantization ----------------------------------------

    #[test]
    fn test_absmax_quantize_preserves_sign() {
        let input = [1.0f32, -1.0, 0.5, -0.5];
        let (q, _scale) = absmax_quantize_activations(&input);
        assert!(q[0] > 0, "Positive input should quantize to positive");
        assert!(q[1] < 0, "Negative input should quantize to negative");
        assert!(q[2] > 0, "Positive input should quantize to positive");
        assert!(q[3] < 0, "Negative input should quantize to negative");
    }

    #[test]
    fn test_absmax_quantize_relative_magnitude() {
        let input = [1.0f32, 0.5, 0.25];
        let (q, _scale) = absmax_quantize_activations(&input);
        assert_eq!(q[0], 127);
        assert!(
            (q[1] as i32 - 64).abs() <= 1,
            "0.5 should map to ~64, got {}",
            q[1]
        );
        assert!(
            (q[2] as i32 - 32).abs() <= 1,
            "0.25 should map to ~32, got {}",
            q[2]
        );
    }

    #[test]
    fn test_absmax_quantize_all_zeros() {
        let input = [0.0f32; 16];
        let (q, scale) = absmax_quantize_activations(&input);
        assert!(
            q.iter().all(|&x| x == 0),
            "All-zero input should give all-zero output"
        );
        assert_eq!(scale, 1.0, "Scale for all-zero should be 1.0");
    }

    #[test]
    fn test_absmax_quantize_empty() {
        let input: Vec<f32> = vec![];
        let (q, scale) = absmax_quantize_activations(&input);
        assert!(q.is_empty());
        assert_eq!(scale, 1.0);
    }

    #[test]
    fn test_absmax_quantize_single_element() {
        let (q, scale) = absmax_quantize_activations(&[3.14]);
        assert_eq!(q[0], 127, "Single positive element should map to 127");
        let expected_scale = 127.0 / 3.14;
        assert!(
            (scale - expected_scale).abs() < EPSILON,
            "Scale mismatch: expected {}, got {}",
            expected_scale,
            scale
        );
    }

    #[test]
    fn test_absmax_quantize_negative_dominant() {
        let (q, scale) = absmax_quantize_activations(&[-10.0, 1.0, -5.0, 0.5]);
        assert_eq!(q[0], -127, "-10.0 should map to -127");
        let expected_scale = 127.0 / 10.0;
        assert!(
            (scale - expected_scale).abs() < EPSILON,
            "Scale should be 127/10"
        );
    }

    // ---- scalar GEMV kernel ----------------------------------------------

    /// Packs `weights`, quantizes `activations`, and runs the scalar kernel
    /// over `rows` output rows with a unit weight scale.
    fn run_scalar(weights: &[i8], activations: &[f32], rows: usize) -> Vec<f32> {
        let packed = pack_ternary(weights);
        let scales = vec![1.0f32];
        let (act_i8, act_scale) = absmax_quantize_activations(activations);
        let mut output = vec![0.0f32; rows];
        tl1_gemv_scalar(
            &packed,
            &scales,
            &act_i8,
            act_scale,
            rows,
            activations.len(),
            &mut output,
        );
        output
    }

    #[test]
    fn test_scalar_gemv_identity_row() {
        let output = run_scalar(&[1, 1, 1, 1], &[1.0, 2.0, 3.0, 4.0], 1);
        let expected = 10.0;
        assert!(
            (output[0] - expected).abs() < 0.5,
            "Identity row GEMV: expected ~{}, got {}",
            expected,
            output[0]
        );
    }

    #[test]
    fn test_scalar_gemv_negation_row() {
        let output = run_scalar(&[-1, -1, -1, -1], &[1.0, 2.0, 3.0, 4.0], 1);
        let expected = -10.0;
        assert!(
            (output[0] - expected).abs() < 0.5,
            "Negation row GEMV: expected ~{}, got {}",
            expected,
            output[0]
        );
    }

    #[test]
    fn test_scalar_gemv_zero_weights() {
        let output = run_scalar(&[0; 8], &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], 1);
        assert!(
            output[0].abs() < EPSILON,
            "Zero weights should give zero output, got {}",
            output[0]
        );
    }

    #[test]
    fn test_scalar_gemv_zero_activations() {
        let output = run_scalar(&[1, -1, 1, -1], &[0.0; 4], 1);
        assert!(
            output[0].abs() < EPSILON,
            "Zero activations should give zero output, got {}",
            output[0]
        );
    }

    #[test]
    fn test_scalar_gemv_multiple_rows() {
        let output = run_scalar(&[1, 1, 1, 1, -1, -1, -1, -1], &[1.0, 2.0, 3.0, 4.0], 2);
        assert!(
            (output[0] - 10.0).abs() < 0.5,
            "Row 0: expected ~10.0, got {}",
            output[0]
        );
        assert!(
            (output[1] - (-10.0)).abs() < 0.5,
            "Row 1: expected ~-10.0, got {}",
            output[1]
        );
    }

    // ---- full GEMV dispatch ----------------------------------------------

    #[test]
    fn test_tl1_gemv_roundtrip_simple() {
        // Uniform 0.5 weights ternarize to +1 with scale 0.5, so each of the
        // four rows should dot to ~4 * 0.5 = 2.0 against unit activations.
        let fp32_weights = vec![0.5f32; 16];
        let (ternary_vals, scale) = absmean_ternary(&fp32_weights);
        let weights = TernaryTensor {
            packed_data: pack_ternary(&ternary_vals),
            scales: vec![scale],
            shape: (4, 4),
            block_size: BLOCK_SIZE,
        };
        let activations = vec![1.0f32; 4];
        let mut output = vec![0.0f32; 4];
        tl1_gemv(&weights, &activations, &mut output);
        for (i, &val) in output.iter().enumerate() {
            assert!(
                (val - 2.0).abs() < 0.5,
                "Row {}: expected ~2.0, got {}",
                i,
                val
            );
        }
    }

    #[test]
    fn test_tl1_gemv_vs_fp32_reference() {
        let out_features = 4;
        let in_features = 8;
        let ternary_vals = vec![
            1i8, 0, -1, 1, 0, 1, -1, 0, // row 0
            -1, 1, 0, -1, 1, 0, 1, -1, // row 1
            0, 0, 1, 1, -1, -1, 0, 0, // row 2
            1, 1, 1, 1, 1, 1, 1, 1, // row 3
        ];
        let weight_scale = 0.5f32;
        let weights = TernaryTensor {
            packed_data: pack_ternary(&ternary_vals),
            scales: vec![weight_scale],
            shape: (out_features, in_features),
            block_size: BLOCK_SIZE,
        };
        let activations = vec![1.0, -1.0, 2.0, -2.0, 0.5, -0.5, 1.5, -1.5];
        let mut output = vec![0.0f32; out_features];
        tl1_gemv(&weights, &activations, &mut output);
        // fp32 reference computed from the unpacked ternary values.
        let mut reference = vec![0.0f32; out_features];
        for r in 0..out_features {
            let mut dot = 0.0f32;
            for c in 0..in_features {
                dot += (ternary_vals[r * in_features + c] as f32) * activations[c];
            }
            reference[r] = dot * weight_scale;
        }
        for (i, (&out, &ref_val)) in output.iter().zip(reference.iter()).enumerate() {
            // Allow quantization error: absolute floor plus 5% relative slack.
            let abs_tol = 0.3 + ref_val.abs() * 0.05;
            assert!(
                (out - ref_val).abs() < abs_tol,
                "Row {}: TL1={:.4}, ref={:.4}, diff={:.4}, tol={:.4}",
                i,
                out,
                ref_val,
                (out - ref_val).abs(),
                abs_tol
            );
        }
    }

    #[test]
    fn test_tl1_gemv_single_element() {
        // 1x1 matrix: +1 weight, scale 2.0, activation 3.0 -> ~6.0.
        let weights = TernaryTensor {
            packed_data: pack_ternary(&[1]),
            scales: vec![2.0f32],
            shape: (1, 1),
            block_size: BLOCK_SIZE,
        };
        let mut output = vec![0.0f32; 1];
        tl1_gemv(&weights, &[3.0], &mut output);
        assert!(
            (output[0] - 6.0).abs() < 0.5,
            "Single element: expected ~6.0, got {}",
            output[0]
        );
    }

    #[test]
    fn test_decode_ternary_2bit_values() {
        assert_eq!(decode_ternary_2bit(0b00), -1);
        assert_eq!(decode_ternary_2bit(0b01), 0);
        assert_eq!(decode_ternary_2bit(0b10), 1);
        assert_eq!(decode_ternary_2bit(0b11), 0);
    }

    #[test]
    fn test_tl1_gemv_dimension_mismatch_panics() {
        let weights = TernaryTensor {
            packed_data: vec![0u8; 1],
            scales: vec![1.0],
            shape: (1, 4),
            block_size: BLOCK_SIZE,
        };
        // 8 activations against 4 weight columns must trip the assert.
        let result = std::panic::catch_unwind(|| {
            let activations = vec![1.0f32; 8];
            let mut output = vec![0.0f32; 1];
            tl1_gemv(&weights, &activations, &mut output);
        });
        assert!(result.is_err(), "Should panic on dimension mismatch");
    }

    #[test]
    fn test_tl1_gemv_larger_matrix() {
        let out_features = 16;
        let in_features = 32;
        // Alternating +1/-1 weights against uniform activations cancel out.
        let ternary_vals: Vec<i8> = (0..out_features * in_features)
            .map(|i| if i % 2 == 0 { 1 } else { -1 })
            .collect();
        let num_blocks = (out_features * in_features + BLOCK_SIZE - 1) / BLOCK_SIZE;
        let weights = TernaryTensor {
            packed_data: pack_ternary(&ternary_vals),
            scales: vec![1.0f32; num_blocks],
            shape: (out_features, in_features),
            block_size: BLOCK_SIZE,
        };
        let activations = vec![1.0f32; in_features];
        let mut output = vec![0.0f32; out_features];
        tl1_gemv(&weights, &activations, &mut output);
        for (i, &val) in output.iter().enumerate() {
            assert!(
                val.abs() < 0.5,
                "Row {}: alternating weights with uniform act should be ~0, got {}",
                i,
                val
            );
        }
    }
}