#![cfg(feature = "cuda")]
use candle_core::{DType, Device, Tensor};
use cubecl::prelude::*;
use crate::config::BitNetConfig;
use crate::error::BitNetError;
use crate::quantization::TernaryWeight;
/// Nominal number of units per cube for the reduction kernels.
///
/// No kernel visible in this file reads this constant (the reductions
/// hard-code 256 shared slots and a starting stride of 128), hence the
/// `dead_code` suppression.
// NOTE(review): presumably consumed by launch code outside this view - confirm.
#[allow(dead_code)]
const BLOCK_SIZE: u32 = 256;
/// Number of output columns handled per cube by the matmul kernels; also the
/// length of the shared input tile staged by `ternary_matmul_kernel`.
const TILE_SIZE: u32 = 32;
/// Capacity of the shared cache of layer-normalized inputs in
/// `bitlinear_forward_kernel`; features past this index are recomputed on use.
const MAX_SHARED_ELEMENTS: u32 = 1024;
/// Per-row absmean ternary quantization.
///
/// The cube at `CUBE_POS_X` processes row `CUBE_POS_X` of `weight`, which is
/// indexed as a flat row-major buffer with `in_features` elements per row.
/// For that row it writes:
/// * `scales[row]` = mean of `|w|` over the row, and
/// * `quantized[row * in_features + i]` in `{-1, 0, 1}`, obtained by rounding
///   `w / (scale + 1e-8)` and clamping.
///
/// There is no bounds check on the row index.
// NOTE(review): the shared buffer has 256 slots indexed by `UNIT_POS_X` and the
// reduction starts at stride 128, so this appears to require CUBE_DIM_X <= 256
// - confirm against the launch configuration.
#[cube(launch)]
fn absmean_quantize_kernel<F: Float>(
weight: &Array<F>, quantized: &mut Array<i32>, scales: &mut Array<F>, in_features: u32,
) {
let row = CUBE_POS_X;
let tid = UNIT_POS_X;
let block_size = CUBE_DIM_X;
let row_start = row * in_features;
let mut shared_sum = SharedMemory::<F>::new(256usize);
// Pass 1: each unit sums |w| over elements tid, tid + block_size, ...
let mut local_sum = F::new(0.0);
let mut i = tid;
while i < in_features {
let val = weight[(row_start + i) as usize];
let abs_val = select(val < F::new(0.0), F::new(0.0) - val, val);
local_sum = local_sum + abs_val;
i = i + block_size;
}
shared_sum[tid as usize] = local_sum;
sync_cube();
// Halving-stride reduction of the per-unit partial sums into slot 0. The
// `(tid + stride) < block_size` guard skips partners beyond the cube size.
let mut stride: u32 = 128;
while stride > 0 {
if tid < stride && (tid + stride) < block_size {
shared_sum[tid as usize] =
shared_sum[tid as usize] + shared_sum[(tid + stride) as usize];
}
sync_cube();
stride = stride / 2;
}
// Every unit reads the row total; only unit 0 stores the scale.
let scale = shared_sum[0usize] / F::cast_from(in_features);
if tid == 0 {
scales[row as usize] = scale;
}
// The epsilon keeps the reciprocal finite when the row is all zeros.
let eps = F::new(1e-8);
let inv_scale = F::new(1.0) / (scale + eps);
sync_cube();
// Pass 2: scale, round and clamp each element to a trit.
i = tid;
while i < in_features {
let val = weight[(row_start + i) as usize] * inv_scale;
// floor(x + 0.5) rounds ties upward: +0.5 becomes 1 while -0.5 becomes 0.
// NOTE(review): the host `absmean_quantize` uses `f32::round`, which sends
// -0.5 to -1, so the two paths can disagree on exact ties.
let rounded = F::floor(val + F::new(0.5));
// Anything below -0.5 maps to -1, above 0.5 to 1, otherwise 0.
let clamped: i32 = select(
rounded < F::new(-0.5),
-1,
select(rounded > F::new(0.5), 1, 0),
);
quantized[(row_start + i) as usize] = clamped;
i = i + block_size;
}
}
/// Element-wise dequantization: `output[idx] = trit(ternary[idx]) * scales[row]`.
///
/// Each unit handles the single flat index `ABSOLUTE_POS`; the row is
/// `idx / in_features`. Units whose index is at or beyond `num_elements`
/// exit immediately.
#[cube(launch)]
fn ternary_dequantize_kernel<F: Float>(
ternary: &Array<i32>,
scales: &Array<F>,
output: &mut Array<F>,
in_features: u32,
num_elements: u32,
) {
let idx = ABSOLUTE_POS as u32;
// Guard the tail: the launch may cover more units than there are elements.
if idx >= num_elements {
terminate!();
}
let row = idx / in_features;
let scale = scales[row as usize];
let trit = ternary[idx as usize];
// Any stored value other than 1 or -1 is treated as 0.
let trit_f: F = select(
trit == 1,
F::new(1.0),
select(trit == -1, F::new(-1.0), F::new(0.0)),
);
output[idx as usize] = trit_f * scale;
}
/// Ternary matrix multiply: `output[b, o] = scales[o] * sum_i trit(o, i) * input[b, i]`.
///
/// All buffers are flat and row-major: `input` is `batch_size x in_features`,
/// `weights` is `out_features x in_features` holding values in `{-1, 0, 1}`
/// (any other value is treated as 0), and `output` is
/// `batch_size x out_features`.
///
/// `CUBE_POS_Z` selects the batch row and `CUBE_POS_X` selects a tile of
/// `TILE_SIZE` output columns; each unit owns one output column. The input
/// row is staged through shared memory `TILE_SIZE` elements at a time, with
/// unit `k` loading element `k` of the current tile.
// NOTE(review): the tile staging assumes CUBE_DIM_X == TILE_SIZE - confirm
// against the launch configuration.
#[cube(launch)]
fn ternary_matmul_kernel<F: Float>(
    input: &Array<F>,
    weights: &Array<i32>,
    scales: &Array<F>,
    output: &mut Array<F>,
    batch_size: u32,
    in_features: u32,
    out_features: u32,
) {
    let batch_idx = CUBE_POS_Z;
    let out_tile = CUBE_POS_X;
    let out_local = UNIT_POS_X;
    let out_idx = out_tile * TILE_SIZE + out_local;

    // The batch index is identical for every unit of the cube, so this exit is
    // uniform and cannot leave other units waiting at a barrier.
    if batch_idx >= batch_size {
        terminate!();
    }

    // Units past the last output column must NOT exit early: they still load
    // their element of every input tile and take part in every barrier.
    // Exiting here would leave part of `input_tile` unwritten for the units
    // that do produce output whenever `out_features` is not a multiple of
    // `TILE_SIZE`.
    let has_output = out_idx < out_features;

    let input_base = batch_idx * in_features;
    let weight_base = out_idx * in_features;
    let mut input_tile = SharedMemory::<F>::new(TILE_SIZE as usize);
    let mut acc = F::new(0.0);
    let num_tiles = (in_features + TILE_SIZE - 1) / TILE_SIZE;
    for tile in 0..num_tiles {
        let tile_start = tile * TILE_SIZE;
        let in_idx = tile_start + out_local;
        // Cooperative load; positions past the end of the row are zero-filled.
        if in_idx < in_features {
            input_tile[out_local as usize] = input[(input_base + in_idx) as usize];
        } else {
            input_tile[out_local as usize] = F::new(0.0);
        }
        sync_cube();
        for i in 0u32..TILE_SIZE {
            let global_in_idx = tile_start + i;
            // Only units that own a real output column may touch `weights`.
            if has_output && global_in_idx < in_features {
                let trit = weights[(weight_base + global_in_idx) as usize];
                let x = input_tile[i as usize];
                // Multiplication-free accumulate: +x for 1, -x for -1, skip otherwise.
                acc = select(trit == 1, acc + x, select(trit == -1, acc - x, acc));
            }
        }
        // Keep the tile intact until every unit has finished reading it.
        sync_cube();
    }
    if has_output {
        let scale = scales[out_idx as usize];
        output[(batch_idx * out_features + out_idx) as usize] = acc * scale;
    }
}
/// Ternary matrix multiply over 2-bit packed weights.
///
/// Each `u32` of `packed_weights` holds 16 trits, trit `i` of a word living in
/// bits `2*i .. 2*i + 1`. Code `1` adds the input element, code `2` subtracts
/// it, and any other code contributes nothing. Each weight row occupies
/// `ceil(in_features / 16)` words.
///
/// `CUBE_POS_Z` selects the batch row and each unit produces the single output
/// column `CUBE_POS_X * TILE_SIZE + UNIT_POS_X`. The kernel uses no shared
/// memory or barriers, so out-of-range units simply exit.
#[cube(launch)]
fn packed_ternary_matmul_kernel<F: Float>(
input: &Array<F>,
packed_weights: &Array<u32>, scales: &Array<F>,
output: &mut Array<F>,
batch_size: u32,
in_features: u32,
out_features: u32,
) {
let batch_idx = CUBE_POS_Z;
let out_tile = CUBE_POS_X;
let out_local = UNIT_POS_X;
let out_idx = out_tile * TILE_SIZE + out_local;
if batch_idx >= batch_size || out_idx >= out_features {
terminate!();
}
let input_base = batch_idx * in_features;
// Words per weight row, rounded up so a partial final word is included.
let packed_per_row = (in_features + 15) / 16; let weight_base = out_idx * packed_per_row;
let mut acc = F::new(0.0);
for pack_idx in 0u32..packed_per_row {
let packed = packed_weights[(weight_base + pack_idx) as usize];
for i in 0u32..16u32 {
let in_idx = pack_idx * 16 + i;
// Skip the padding slots of the final word.
if in_idx < in_features {
let shift = i * 2;
let trit_bits = (packed >> shift) & 0x3u32;
let x = input[(input_base + in_idx) as usize];
// Code 1 adds x, code 2 subtracts x, codes 0 and 3 leave acc unchanged.
acc = select(
trit_bits == 1u32,
acc + x,
select(trit_bits == 2u32, acc - x, acc),
);
}
}
}
let scale = scales[out_idx as usize];
output[(batch_idx * out_features + out_idx) as usize] = acc * scale;
}
/// Fused layer normalization followed by a ternary linear projection.
///
/// `CUBE_POS_Y` selects the batch row and `CUBE_POS_X` a tile of `TILE_SIZE`
/// output columns. The cube first reduces the mean and the (population)
/// variance of the input row, then caches up to `MAX_SHARED_ELEMENTS`
/// normalized and affine-transformed inputs in shared memory, and finally
/// accumulates the ternary dot product for each output column of the tile:
///
/// `output[b, o] = weight_scales[o] * sum_i trit(o, i) * (norm(x[b, i]) * ln_weight[i] + ln_bias[i])`
///
/// with `norm(x) = (x - mean) / sqrt(var + 1e-5)`. Weight values other than
/// `1` and `-1` are treated as 0. Features at index `MAX_SHARED_ELEMENTS` and
/// beyond are re-normalized on the fly from `input`.
// NOTE(review): the reduction buffer has 256 slots indexed by `UNIT_POS_X` and
// the reductions start at stride 128, so this appears to require
// CUBE_DIM_X <= 256 - confirm against the launch configuration.
// NOTE(review): the epsilon is hard-coded to 1e-5 here, whereas the host-side
// `bitlinear_forward` reads it from its config - confirm they agree.
#[cube(launch)]
fn bitlinear_forward_kernel<F: Float>(
    input: &Array<F>,
    weights: &Array<i32>,
    weight_scales: &Array<F>,
    ln_weight: &Array<F>,
    ln_bias: &Array<F>,
    output: &mut Array<F>,
    batch_size: u32,
    in_features: u32,
    out_features: u32,
) {
    let batch_idx = CUBE_POS_Y;
    let out_tile = CUBE_POS_X;
    let tid = UNIT_POS_X;
    let block_size = CUBE_DIM_X;
    let out_idx = out_tile * TILE_SIZE + (tid % TILE_SIZE);

    // The batch index is identical for every unit of the cube, so this exit is
    // uniform and cannot leave other units waiting at a barrier.
    if batch_idx >= batch_size {
        terminate!();
    }

    // Units whose column lies past `out_features` must still take part in the
    // reductions and barriers below: exiting early would drop their partial
    // sums and leave their shared slots unwritten, corrupting the mean and
    // variance seen by the rest of the cube. Several units can map to the same
    // column when the cube is wider than `TILE_SIZE`; only the first
    // `TILE_SIZE` units compute and store a result, the rest would merely
    // duplicate it.
    let writes_output = out_idx < out_features && tid < TILE_SIZE;

    let input_base = batch_idx * in_features;
    let mut shared = SharedMemory::<F>::new(256usize);
    let mut normed_cache = SharedMemory::<F>::new(MAX_SHARED_ELEMENTS as usize);

    // Mean: per-unit partial sums over a strided slice of the row...
    let mut local_sum = F::new(0.0);
    let mut i = tid;
    while i < in_features {
        local_sum = local_sum + input[(input_base + i) as usize];
        i = i + block_size;
    }
    shared[tid as usize] = local_sum;
    sync_cube();
    // ...combined by a halving-stride reduction into slot 0.
    let mut stride: u32 = 128;
    while stride > 0 {
        if tid < stride && (tid + stride) < block_size {
            shared[tid as usize] = shared[tid as usize] + shared[(tid + stride) as usize];
        }
        sync_cube();
        stride = stride / 2;
    }
    let mean = shared[0usize] / F::cast_from(in_features);
    // Everyone must have read the mean before the buffer is reused.
    sync_cube();

    // Variance: same two-step reduction over squared deviations.
    local_sum = F::new(0.0);
    i = tid;
    while i < in_features {
        let diff = input[(input_base + i) as usize] - mean;
        local_sum = local_sum + diff * diff;
        i = i + block_size;
    }
    shared[tid as usize] = local_sum;
    sync_cube();
    stride = 128;
    while stride > 0 {
        if tid < stride && (tid + stride) < block_size {
            shared[tid as usize] = shared[tid as usize] + shared[(tid + stride) as usize];
        }
        sync_cube();
        stride = stride / 2;
    }
    let var = shared[0usize] / F::cast_from(in_features);
    let eps = F::new(1e-5);
    let inv_std = F::new(1.0) / F::sqrt(var + eps);
    sync_cube();

    // Cache the normalized, affine-transformed inputs that fit in shared memory.
    i = tid;
    while i < in_features && i < MAX_SHARED_ELEMENTS {
        let norm = (input[(input_base + i) as usize] - mean) * inv_std;
        normed_cache[i as usize] = norm * ln_weight[i as usize] + ln_bias[i as usize];
        i = i + block_size;
    }
    sync_cube();

    let weight_base = out_idx * in_features;
    let mut acc = F::new(0.0);
    i = 0;
    if writes_output {
        while i < in_features {
            // Use the cache when possible; recompute for features beyond it.
            let normed_val: F = if i < MAX_SHARED_ELEMENTS {
                normed_cache[i as usize]
            } else {
                let norm = (input[(input_base + i) as usize] - mean) * inv_std;
                norm * ln_weight[i as usize] + ln_bias[i as usize]
            };
            let trit = weights[(weight_base + i) as usize];
            // Multiplication-free accumulate: +v for 1, -v for -1, skip otherwise.
            acc = select(
                trit == 1,
                acc + normed_val,
                select(trit == -1, acc - normed_val, acc),
            );
            i = i + 1;
        }
        let scale = weight_scales[out_idx as usize];
        output[(batch_idx * out_features + out_idx) as usize] = acc * scale;
    }
}
#[must_use]
pub fn has_cuda_support() -> bool {
matches!(Device::cuda_if_available(0), Ok(Device::Cuda(_)))
}
/// Quantizes a 2-D weight tensor to ternary values with one absmean scale per row.
///
/// For every row the scale is the mean absolute value of the row. Each element
/// is divided by that scale (rows whose scale is not above `1e-8` are left
/// unscaled), rounded with `f32::round`, and clamped to `[-1, 1]`.
///
/// The arithmetic runs on the host: the tensor is copied out as `f32`, and the
/// results are uploaded back to the tensor's device.
///
/// # Returns
/// `(quantized, scales)` where `quantized` has the shape of `weight` and holds
/// `i32` trits, and `scales` has one `f32` entry per row.
///
/// # Errors
/// * `BitNetError::FeatureNotAvailable` if `weight` is not on a CUDA device.
/// * Any tensor error from shape inspection, dtype conversion, or tensor construction.
pub fn absmean_quantize(weight: &Tensor) -> std::result::Result<(Tensor, Tensor), BitNetError> {
    let device = weight.device();
    if !device.is_cuda() {
        return Err(BitNetError::FeatureNotAvailable(
            "absmean_quantize requires CUDA device".into(),
        ));
    }

    let (out_features, in_features) = weight.dims2()?;
    let host_values: Vec<f32> = weight.to_dtype(DType::F32)?.flatten_all()?.to_vec1()?;

    let mut scales: Vec<f32> = Vec::with_capacity(out_features);
    let mut quantized: Vec<i32> = Vec::with_capacity(out_features * in_features);
    for row in 0..out_features {
        let values = &host_values[row * in_features..(row + 1) * in_features];

        let scale = values.iter().map(|v| v.abs()).sum::<f32>() / in_features as f32;
        scales.push(scale);

        // Near-zero rows are passed through unscaled rather than divided by ~0.
        let inv_scale = if scale > 1e-8 { scale.recip() } else { 1.0 };
        quantized.extend(
            values
                .iter()
                .map(|&v| (v * inv_scale).round().clamp(-1.0, 1.0) as i32),
        );
    }

    let quantized_tensor = Tensor::from_vec(quantized, (out_features, in_features), device)?;
    let scales_tensor = Tensor::from_vec(scales, out_features, device)?;
    Ok((quantized_tensor, scales_tensor))
}
/// Expands ternary weights back to `f32` by multiplying each row by its scale.
///
/// `ternary` must be 2-D; `scales` is reshaped to a column with one entry per
/// row of `ternary` and broadcast across that row.
///
/// # Errors
/// Returns any tensor error raised by the shape check, the dtype conversion,
/// the reshape of `scales`, or the broadcast multiply.
pub fn ternary_dequantize(
    ternary: &Tensor,
    scales: &Tensor,
) -> std::result::Result<Tensor, BitNetError> {
    let (rows, _) = ternary.dims2()?;
    let trits_as_f32 = ternary.to_dtype(DType::F32)?;
    let per_row_scale = scales.reshape((rows, 1))?;
    Ok(trits_as_f32.broadcast_mul(&per_row_scale)?)
}
/// Multiplies `input` by the transpose of a quantized weight.
///
/// The weight is first dequantized onto the device of `input`, then a regular
/// dense `matmul` against its transpose is performed.
///
/// # Errors
/// Propagates failures from dequantization, the transpose, or the matmul.
pub fn ternary_matmul_gpu(
    input: &Tensor,
    weight: &TernaryWeight,
) -> std::result::Result<Tensor, BitNetError> {
    let dense_weight = crate::quantization::dequantize_weights(weight, input.device())?;
    let dense_weight_t = dense_weight.t()?;
    input.matmul(&dense_weight_t).map_err(BitNetError::from)
}
/// Multiplies `input` by the transpose of raw ternary weights with per-row scales.
///
/// The weights are dequantized with [`ternary_dequantize`] and the product is
/// computed with a dense `matmul`.
///
/// # Errors
/// Propagates failures from dequantization, the transpose, or the matmul.
pub fn ternary_matmul_raw(
    input: &Tensor,
    ternary_weights: &Tensor,
    scales: &Tensor,
) -> std::result::Result<Tensor, BitNetError> {
    let dense_weight_t = ternary_dequantize(ternary_weights, scales)?.t()?;
    Ok(input.matmul(&dense_weight_t)?)
}
/// Packs a 2-D tensor of `i32` trits into `u32` words, 16 trits per word.
///
/// Trit `i` of a word occupies bits `2*i .. 2*i + 1`: `1` is stored as `0b01`,
/// `-1` as `0b10`, and every other value as `0b00`. Each row is packed
/// independently into `ceil(in_features / 16)` words; unused slots of the last
/// word stay zero. The packing itself runs on the host.
///
/// # Errors
/// Returns any tensor error from the shape check, host copy, or tensor construction.
pub fn pack_ternary_weights(ternary: &Tensor) -> std::result::Result<Tensor, BitNetError> {
    const TRITS_PER_WORD: usize = 16;

    /// Folds up to sixteen trits into one word, first trit in the lowest bits.
    fn pack_word(trits: &[i32]) -> u32 {
        trits.iter().enumerate().fold(0u32, |word, (slot, &trit)| {
            let code = match trit {
                1 => 0b01u32,
                -1 => 0b10u32,
                _ => 0b00u32,
            };
            word | (code << (slot * 2))
        })
    }

    let (out_features, in_features) = ternary.dims2()?;
    let device = ternary.device();
    let trits: Vec<i32> = ternary.flatten_all()?.to_vec1()?;
    let packed_per_row = (in_features + TRITS_PER_WORD - 1) / TRITS_PER_WORD;

    let mut packed: Vec<u32> = Vec::with_capacity(out_features * packed_per_row);
    for row in 0..out_features {
        let row_trits = &trits[row * in_features..(row + 1) * in_features];
        packed.extend(row_trits.chunks(TRITS_PER_WORD).map(pack_word));
    }

    Ok(Tensor::from_vec(
        packed,
        (out_features, packed_per_row),
        device,
    )?)
}
/// Multiplies `input` by the transpose of 2-bit packed ternary weights.
///
/// The packed weights are expanded with [`unpack_ternary_weights`] and the
/// product is then computed by [`ternary_matmul_raw`].
///
/// # Arguments
/// * `input` - activations to multiply.
/// * `packed_weights` - 2-D tensor of packed words, one row per output feature.
/// * `scales` - one scale per output feature.
/// * `in_features` - number of trits stored per row.
///
/// # Errors
/// Propagates any error from unpacking or from the matmul. Malformed shapes
/// (including rank-0 tensors) are reported as errors rather than panicking.
pub fn packed_ternary_matmul(
    input: &Tensor,
    packed_weights: &Tensor,
    scales: &Tensor,
    in_features: usize,
) -> std::result::Result<Tensor, BitNetError> {
    // No direct `dims()[0]` indexing here: it was unused and would panic on a
    // rank-0 tensor before the fallible calls below could report the problem.
    let ternary = unpack_ternary_weights(packed_weights, in_features)?;
    ternary_matmul_raw(input, &ternary, scales)
}
/// Expands 2-bit packed ternary weights back into a 2-D tensor of `i32` trits.
///
/// Each word holds 16 codes, code `i` in bits `2*i .. 2*i + 1`: `0b01` decodes
/// to `1`, `0b10` to `-1`, and anything else to `0`. Only the first
/// `in_features` codes of each row are kept; if a row has fewer codes than
/// `in_features`, the remaining outputs stay `0`. Decoding runs on the host.
///
/// # Errors
/// Returns any tensor error from the shape check, dtype conversion, host copy,
/// or tensor construction.
pub fn unpack_ternary_weights(
    packed: &Tensor,
    in_features: usize,
) -> std::result::Result<Tensor, BitNetError> {
    const TRITS_PER_WORD: u32 = 16;

    let (out_features, packed_per_row) = packed.dims2()?;
    let device = packed.device();
    let words: Vec<u32> = packed.to_dtype(DType::U32)?.flatten_all()?.to_vec1()?;

    let mut ternary = vec![0i32; out_features * in_features];
    for row in 0..out_features {
        let row_words = &words[row * packed_per_row..(row + 1) * packed_per_row];
        let row_trits = &mut ternary[row * in_features..(row + 1) * in_features];

        // Stream of 2-bit codes for this row, lowest bits of each word first.
        let codes = row_words
            .iter()
            .flat_map(|&word| (0..TRITS_PER_WORD).map(move |slot| (word >> (slot * 2)) & 0x3));

        // `zip` stops at whichever runs out first: the codes or the row.
        for (trit, code) in row_trits.iter_mut().zip(codes) {
            *trit = match code {
                0b01 => 1i32,
                0b10 => -1i32,
                _ => 0i32,
            };
        }
    }

    Ok(Tensor::from_vec(
        ternary,
        (out_features, in_features),
        device,
    )?)
}
/// BitLinear forward pass: layer normalization followed by a ternary matmul.
///
/// The input is normalized along dimension 1 (mean and population variance,
/// with `config.eps` added to the variance), scaled by `ln_weight`, shifted by
/// `ln_bias`, and then passed to [`ternary_matmul_gpu`].
///
/// # Errors
/// Propagates any tensor error from the normalization steps or from the matmul.
pub fn bitlinear_forward(
    input: &Tensor,
    weight: &TernaryWeight,
    ln_weight: &Tensor,
    ln_bias: &Tensor,
    config: &BitNetConfig,
) -> std::result::Result<Tensor, BitNetError> {
    let row_mean = input.mean_keepdim(1)?;
    let deviations = input.broadcast_sub(&row_mean)?;
    let variance = deviations.sqr()?.mean_keepdim(1)?;
    let std_dev = (variance + config.eps as f64)?.sqrt()?;

    let layer_normed = deviations
        .broadcast_div(&std_dev)?
        .broadcast_mul(ln_weight)?
        .broadcast_add(ln_bias)?;

    ternary_matmul_gpu(&layer_normed, weight)
}
/// Heuristic deciding whether a matmul is large enough to be worth running on the GPU.
///
/// Returns `false` whenever `input` is not on a CUDA device. Otherwise returns
/// `true` when the product of the input element count and the weight element
/// count exceeds 65536.
///
/// The products saturate instead of overflowing, so extremely large shapes are
/// still classified as GPU work rather than panicking (debug) or wrapping
/// around to a small number (release).
#[must_use]
pub fn should_use_gpu(input: &Tensor, weight: &TernaryWeight) -> bool {
    /// Work threshold (input elements x weight elements) above which the GPU is used.
    const MIN_GPU_WORK: usize = 65536;

    if !input.device().is_cuda() {
        return false;
    }
    let weight_size = weight.out_features().saturating_mul(weight.in_features());
    input.elem_count().saturating_mul(weight_size) > MIN_GPU_WORK
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::BitNetConfig;
    use crate::quantization::quantize_weights;
    use candle_core::Device;

    /// The dequantize-then-matmul path also works on CPU tensors and yields
    /// a `(batch, out_features)` result.
    #[test]
    fn test_ternary_matmul_cpu_fallback() {
        let cpu = Device::Cpu;
        let cfg = BitNetConfig::default().with_group_size(64);
        let dense = candle_core::Tensor::randn(0.0f32, 1.0, (64, 128), &cpu).unwrap();
        let quantized = quantize_weights(&dense, &cfg).unwrap();
        let activations = candle_core::Tensor::randn(0.0f32, 1.0, (4, 128), &cpu).unwrap();

        let result = ternary_matmul_gpu(&activations, &quantized).unwrap();

        assert_eq!(result.shape().dims(), &[4, 64]);
    }

    /// Packing one full word of trits and unpacking it returns the original values.
    #[test]
    fn test_pack_unpack_roundtrip() {
        let cpu = Device::Cpu;
        let trits: Vec<i32> = vec![1, 0, -1, 1, 0, -1, 0, 1, -1, 0, 1, -1, 0, 0, 1, -1];
        let original = Tensor::from_vec(trits.clone(), (1, 16), &cpu).unwrap();

        let restored = pack_ternary_weights(&original)
            .and_then(|packed| unpack_ternary_weights(&packed, 16))
            .unwrap();

        let restored_trits: Vec<i32> = restored.flatten_all().unwrap().to_vec1().unwrap();
        assert_eq!(trits, restored_trits);
    }

    /// Each row of trits is multiplied by that row's scale.
    #[test]
    fn test_ternary_dequantize() {
        let cpu = Device::Cpu;
        let trits = Tensor::from_vec(vec![1i32, 0, -1, 1], (2, 2), &cpu).unwrap();
        let row_scales = Tensor::from_vec(vec![2.0f32, 0.5], 2, &cpu).unwrap();

        let dense = ternary_dequantize(&trits, &row_scales).unwrap();
        let values: Vec<f32> = dense.flatten_all().unwrap().to_vec1().unwrap();

        let expected = [2.0f32, 0.0, -0.5, 0.5];
        for (idx, want) in expected.iter().enumerate() {
            assert!((values[idx] - want).abs() < 1e-6);
        }
    }

    /// `absmean_quantize` refuses tensors that are not on a CUDA device.
    #[test]
    fn test_absmean_quantize_cpu() {
        let cpu = Device::Cpu;
        let samples = vec![1.0f32, -0.5, 0.2, -0.8, 0.1, 0.9, -0.3, 0.0];
        let weight = Tensor::from_vec(samples, (2, 4), &cpu).unwrap();

        assert!(absmean_quantize(&weight).is_err());
    }
}