use candle_core::{Result, Tensor};
use candle_nn::VarBuilder;
use std::collections::HashMap;
use tracing::debug;
/// A linear layer with 4-bit quantized weights: two signed 4-bit values are
/// packed per byte and dequantized with per-group scales at compute time.
#[derive(Clone, Debug)]
pub struct Linear4Bit {
    /// Packed weights, shape `[out_features, in_features / 2]`, dtype `U8`.
    /// Low nibble holds the even input index, high nibble the odd one
    /// (see `unpack_4bit`).
    pub weight_packed: Tensor,
    /// Per-group dequantization scales, shape
    /// `[out_features, ceil(in_features / group_size)]`.
    pub scales: Tensor,
    /// Optional bias, shape `[out_features]`, broadcast-added after the matmul.
    pub bias: Option<Tensor>,
    /// Number of consecutive input columns that share one scale.
    pub group_size: usize,
    /// Logical input width (unpacked column count).
    pub in_features: usize,
    /// Output width (row count of the weight matrix).
    pub out_features: usize,
}
impl Linear4Bit {
    /// Validates tensor shapes and assembles a `Linear4Bit`.
    ///
    /// # Arguments
    /// * `weight_packed` - `[out_features, in_features / 2]` `U8` tensor, two
    ///   4-bit weights per byte (low nibble first).
    /// * `scales` - `[out_features, n_groups]` scales with
    ///   `n_groups = ceil(in_features / group_size)`.
    /// * `bias` - optional `[out_features]` bias vector.
    ///
    /// # Errors
    /// Returns `Error::Msg` when any tensor shape disagrees with the declared
    /// `in_features` / `out_features` / `group_size`.
    pub fn new(
        weight_packed: Tensor,
        scales: Tensor,
        bias: Option<Tensor>,
        group_size: usize,
        in_features: usize,
        out_features: usize,
    ) -> Result<Self> {
        let packed_dims = weight_packed.dims();
        let scales_dims = scales.dims();
        // NOTE(review): for odd `in_features` this expects floor(in/2) packed
        // bytes; `unpack_4bit` then zero-fills the final column.
        if packed_dims.len() != 2
            || packed_dims[0] != out_features
            || packed_dims[1] != in_features / 2
        {
            return Err(candle_core::Error::Msg(format!(
                "Invalid weight_packed shape: expected [{}, {}], got {:?}",
                out_features,
                in_features / 2,
                packed_dims
            )));
        }
        let n_groups = in_features.div_ceil(group_size);
        if scales_dims.len() != 2 || scales_dims[0] != out_features || scales_dims[1] != n_groups {
            return Err(candle_core::Error::Msg(format!(
                "Invalid scales shape: expected [{}, {}], got {:?}",
                out_features, n_groups, scales_dims
            )));
        }
        if let Some(ref bias_tensor) = bias {
            let bias_dims = bias_tensor.dims();
            if bias_dims.len() != 1 || bias_dims[0] != out_features {
                return Err(candle_core::Error::Msg(format!(
                    "Invalid bias shape: expected [{}], got {:?}",
                    out_features, bias_dims
                )));
            }
        }
        Ok(Self {
            weight_packed,
            scales,
            bias,
            group_size,
            in_features,
            out_features,
        })
    }

    /// Loads a 4-bit layer from a `VarBuilder`, inferring all dimensions from
    /// the stored tensor shapes of `{prefix}.weight_4bit`,
    /// `{prefix}.scales_4bit` and (optionally) `{prefix}.bias`.
    ///
    /// # Errors
    /// Fails when the weight or scales tensors are missing or not rank 2.
    pub fn load_4bit(vb: &VarBuilder, prefix: &str) -> Result<Self> {
        debug!("Loading 4bit linear layer with prefix: {}", prefix);
        let weight_4bit_name = format!("{}.weight_4bit", prefix);
        let scales_4bit_name = format!("{}.scales_4bit", prefix);
        let bias_name = format!("{}.bias", prefix);
        // NOTE(review): the `()` shape hint appears to rely on the builder
        // resolving the stored shape from the checkpoint — confirm against
        // the VarBuilder backend in use.
        let weight_packed = vb.get_with_hints_dtype(
            (),
            &weight_4bit_name,
            candle_nn::init::ZERO,
            candle_core::DType::U8,
        )?;
        let scales = vb.get_with_hints_dtype(
            (),
            &scales_4bit_name,
            candle_nn::init::ZERO,
            candle_core::DType::F16,
        )?;
        // Bias is optional: a lookup failure is treated as "no bias".
        let bias = match vb.get_with_hints((), &bias_name, candle_nn::init::ZERO) {
            Ok(bias_tensor) => Some(bias_tensor),
            Err(_) => {
                debug!("No bias found for {}, proceeding without bias", prefix);
                None
            }
        };
        let packed_dims = weight_packed.dims();
        let scales_dims = scales.dims();
        if packed_dims.len() != 2 || scales_dims.len() != 2 {
            return Err(candle_core::Error::Msg(format!(
                "Invalid tensor dimensions: weight_packed {:?}, scales {:?}",
                packed_dims, scales_dims
            )));
        }
        let out_features = packed_dims[0];
        // Two 4-bit weights are stored per byte.
        let in_features = packed_dims[1] * 2;
        let n_groups = scales_dims[1];
        let group_size = in_features.div_ceil(n_groups);
        debug!(
            "Loaded 4bit layer: {}x{}, group_size={}, n_groups={}",
            out_features, in_features, group_size, n_groups
        );
        Self::new(
            weight_packed,
            scales,
            bias,
            group_size,
            in_features,
            out_features,
        )
    }

    /// Loads a 4-bit layer from an already-materialized tensor map, moving the
    /// tensors to `device`. `_group_size` / `_symmetric` are accepted for
    /// interface compatibility; the group size is re-derived from the scales
    /// tensor's second dimension.
    ///
    /// # Errors
    /// Fails when the 4-bit weight/scale tensors are absent for `prefix`, or
    /// when the scales tensor is not rank 2.
    pub fn load_direct(
        tensors: &HashMap<String, Tensor>,
        prefix: &str,
        in_dim: usize,
        out_dim: usize,
        _group_size: usize,
        _symmetric: bool,
        device: &candle_core::Device,
    ) -> Result<Self> {
        let weight_4bit_key = format!("{}.weight_4bit", prefix);
        let scales_4bit_key = format!("{}.scales_4bit", prefix);
        let bias_key = format!("{}.bias", prefix);
        if let (Some(weight_packed), Some(scales)) =
            (tensors.get(&weight_4bit_key), tensors.get(&scales_4bit_key))
        {
            debug!("Loading 4bit format for {}", prefix);
            let weight_packed = weight_packed.to_device(device)?;
            let scales = scales.to_device(device)?;
            let bias = tensors
                .get(&bias_key)
                .map(|t| t.to_device(device))
                .transpose()?;
            let scales_dims = scales.dims();
            // Guard before indexing [1]: a malformed rank-1 scales tensor
            // would otherwise panic instead of returning an error.
            if scales_dims.len() != 2 {
                return Err(candle_core::Error::Msg(format!(
                    "Invalid scales shape for {}: expected rank 2, got {:?}",
                    prefix, scales_dims
                )));
            }
            let n_groups = scales_dims[1];
            let group_size = in_dim.div_ceil(n_groups);
            return Self::new(weight_packed, scales, bias, group_size, in_dim, out_dim);
        }
        Err(candle_core::Error::Msg(format!(
            "No 4-bit weights found for prefix: {}",
            prefix
        )))
    }

    /// Dequantizes the packed weights into a dense
    /// `[out_features, in_features]` `F32` tensor. Used by the fallback
    /// `forward_unpack` path; the fast path (`forward`) dequantizes inside
    /// the fused kernel instead.
    fn unpack_4bit(&self) -> Result<Tensor> {
        let device = self.weight_packed.device();
        let packed_data = self.weight_packed.to_vec2::<u8>()?;
        let scales_data = self
            .scales
            .to_dtype(candle_core::DType::F32)?
            .to_vec2::<f32>()?;
        let mut unpacked = Vec::with_capacity(self.out_features * self.in_features);
        for (packed_row, scales_row) in packed_data.iter().zip(scales_data.iter()) {
            let mut row = Vec::with_capacity(self.in_features);
            for (byte_idx, &packed_byte) in packed_row.iter().enumerate() {
                // Low nibble is the earlier input position, high nibble the
                // later one; both are stored offset by 8 (0..=15 -> -8..=7).
                let low = (packed_byte & 0x0F) as i8 - 8;
                let high = ((packed_byte >> 4) & 0x0F) as i8 - 8;
                for (pos, quantized) in [(byte_idx * 2, low), (byte_idx * 2 + 1, high)] {
                    if pos < self.in_features {
                        let scale = scales_row[pos / self.group_size];
                        row.push(quantized as f32 * scale);
                    }
                }
            }
            // Zero-fill any missing trailing column (odd `in_features` packs
            // only floor(in/2) bytes). Replaces the previous pad loop plus a
            // dead `truncate` — `row` can never exceed `in_features` because
            // every push above is bounds-guarded.
            row.resize(self.in_features, 0.0);
            unpacked.extend(row);
        }
        Tensor::from_vec(unpacked, (self.out_features, self.in_features), device)
    }

    /// Flattens a rank > 2 input to 2-D for the matmul, returning the original
    /// dims so the output shape can be restored afterwards. Rank <= 2 inputs
    /// pass through unchanged.
    fn flatten_input(input: &Tensor) -> Result<(Tensor, Option<Vec<usize>>)> {
        if input.rank() > 2 {
            let dims = input.dims();
            let last_dim = dims[dims.len() - 1];
            let batch_size = input.elem_count() / last_dim;
            Ok((input.reshape(&[batch_size, last_dim])?, Some(dims.to_vec())))
        } else {
            Ok((input.clone(), None))
        }
    }

    /// Adds the optional bias, broadcasting it over the batch dimension.
    fn add_bias(&self, output: Tensor) -> Result<Tensor> {
        match &self.bias {
            Some(bias) => output.broadcast_add(bias),
            None => Ok(output),
        }
    }

    /// Restores a 2-D output to the caller's original shape with the last
    /// dimension replaced by `out_features`.
    fn restore_shape(&self, output: Tensor, original_shape: Option<Vec<usize>>) -> Result<Tensor> {
        match original_shape {
            Some(mut dims) => {
                let last_idx = dims.len() - 1;
                dims[last_idx] = self.out_features;
                output.reshape(&dims[..])
            }
            None => Ok(output),
        }
    }

    /// Applies the layer using the fused 4-bit GEMM kernel (no dense weight
    /// materialization). Accepts rank >= 2 inputs; the last dimension must be
    /// `in_features` and is replaced by `out_features` in the output.
    pub fn forward(&self, input: &Tensor) -> Result<Tensor> {
        let (input_2d, original_shape) = Self::flatten_input(input)?;
        let output = crate::kernels::matmul_4bit::gemm_4bit(
            &input_2d,
            &self.weight_packed,
            &self.scales,
            self.group_size,
        )?;
        let output = self.add_bias(output)?;
        self.restore_shape(output, original_shape)
    }

    /// Reference/fallback forward pass: dequantizes the full weight matrix and
    /// uses a plain matmul. Same contract as `forward`, but slower and
    /// memory-hungry — useful for validation.
    pub fn forward_unpack(&self, input: &Tensor) -> Result<Tensor> {
        let (input_2d, original_shape) = Self::flatten_input(input)?;
        let weight = self.unpack_4bit()?;
        let output = input_2d.matmul(&weight.t()?)?;
        let output = self.add_bias(output)?;
        self.restore_shape(output, original_shape)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::{Device, Tensor};

    /// Builds a deterministic 4x8 layer (group_size = 4) with packed weights,
    /// F16 scales and a bias, shared by the tests below.
    fn create_test_layer() -> Result<Linear4Bit> {
        let device = Device::Cpu;
        let out_features = 4usize;
        let in_features = 8usize;
        let group_size = 4usize;
        let n_groups = in_features.div_ceil(group_size);

        let packed_bytes: Vec<u8> = vec![
            0x12, 0x34, 0x56, 0x78, 0x9A, 0xBC, 0xDE, 0xF0, //
            0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88,
        ];
        let weight_packed =
            Tensor::from_vec(packed_bytes, (out_features, in_features / 2), &device)?;

        let scale_rows = vec![
            [0.1f32, 0.2f32],
            [0.15f32, 0.25f32],
            [0.12f32, 0.22f32],
            [0.18f32, 0.28f32],
        ];
        let flat_scales: Vec<f32> = scale_rows.into_iter().flatten().collect();
        let scales = Tensor::from_vec(flat_scales, (out_features, n_groups), &device)?
            .to_dtype(candle_core::DType::F16)?;

        let bias = Some(Tensor::from_vec(
            vec![0.1f32, 0.2f32, 0.3f32, 0.4f32],
            out_features,
            &device,
        )?);

        Linear4Bit::new(
            weight_packed,
            scales,
            bias,
            group_size,
            in_features,
            out_features,
        )
    }

    #[test]
    fn test_linear_4bit_creation() -> Result<()> {
        let layer = create_test_layer()?;
        assert_eq!(layer.in_features, 8);
        assert_eq!(layer.out_features, 4);
        assert_eq!(layer.group_size, 4);
        assert!(layer.bias.is_some());
        Ok(())
    }

    #[test]
    fn test_linear_4bit_unpack() -> Result<()> {
        let layer = create_test_layer()?;
        let unpacked = layer.unpack_4bit()?;
        assert_eq!(unpacked.dims(), [4, 8]);
        Ok(())
    }

    #[test]
    fn test_linear_4bit_forward() -> Result<()> {
        let layer = create_test_layer()?;
        let device = Device::Cpu;

        // 2-D input: batch dimension preserved, features mapped 8 -> 4.
        let input_2d = Tensor::randn(0.0f32, 1.0, (2, 8), &device)?;
        assert_eq!(layer.forward(&input_2d)?.dims(), [2, 4]);

        // Rank-3 input: flattened for the matmul, then restored.
        let input_3d = Tensor::randn(0.0f32, 1.0, (3, 5, 8), &device)?;
        assert_eq!(layer.forward(&input_3d)?.dims(), [3, 5, 4]);
        Ok(())
    }

    #[test]
    fn test_4bit_weight_unpacking() -> Result<()> {
        let device = Device::Cpu;

        // Byte 0x12 with unit scales:
        //   low nibble  2 -> 2 - 8 = -6 (position 0)
        //   high nibble 1 -> 1 - 8 = -7 (position 1)
        let weight_packed = Tensor::from_vec(vec![0x12u8], (1, 1), &device)?;
        let scales = Tensor::from_vec(vec![1.0f32, 1.0f32], (1, 2), &device)?
            .to_dtype(candle_core::DType::F16)?;
        let layer = Linear4Bit::new(weight_packed, scales, None, 1, 2, 1)?;

        let values = layer.unpack_4bit()?.to_vec2::<f32>()?;
        assert_eq!(values[0][0], -6.0);
        assert_eq!(values[0][1], -7.0);
        Ok(())
    }
}