use super::config::TernaryConfig;
use super::matmul::ternary_matmul;
use super::types::TernaryTensor;
use crate::error::{Result, UnslothError};
use candle_core::{Module, Tensor};
/// A linear (fully connected) layer whose weight matrix is stored in packed
/// ternary form ([`TernaryTensor`]), with an optional dense bias.
#[derive(Debug, Clone)]
pub struct TernaryLinear {
    // Packed ternary weights; `dims()` is (out_features, in_features).
    weights: TernaryTensor,
    // Optional 1-D bias of length out_features, broadcast-added in `forward`.
    bias: Option<Tensor>,
    // Quantization/matmul settings consumed by `ternary_matmul`.
    config: TernaryConfig,
}
impl TernaryLinear {
    /// Creates a ternary linear layer with the default [`TernaryConfig`].
    ///
    /// # Errors
    /// Returns [`UnslothError::ShapeMismatch`] if `bias` is present but is not
    /// a 1-D tensor whose length equals the layer's output dimension.
    pub fn new(weights: TernaryTensor, bias: Option<Tensor>) -> Result<Self> {
        Self::with_config(weights, bias, TernaryConfig::default())
    }

    /// Creates a ternary linear layer with an explicit configuration.
    ///
    /// # Errors
    /// Returns [`UnslothError::ShapeMismatch`] if `bias` is present but is not
    /// a 1-D tensor of length `weights.dims().0` (the output feature count).
    pub fn with_config(
        weights: TernaryTensor,
        bias: Option<Tensor>,
        config: TernaryConfig,
    ) -> Result<Self> {
        // Validate the bias shape up front so `forward` never has to.
        if let Some(ref b) = bias {
            let bias_shape = b.shape().dims();
            if bias_shape.len() != 1 || bias_shape[0] != weights.dims().0 {
                return Err(UnslothError::ShapeMismatch {
                    expected: vec![weights.dims().0],
                    actual: bias_shape.to_vec(),
                });
            }
        }
        Ok(Self {
            weights,
            bias,
            config,
        })
    }

    /// Returns `(out_features, in_features)` of the weight matrix.
    #[must_use]
    pub fn dims(&self) -> (usize, usize) {
        self.weights.dims()
    }

    /// Number of input features (second weight dimension).
    #[must_use]
    pub fn in_features(&self) -> usize {
        self.weights.dims().1
    }

    /// Number of output features (first weight dimension).
    #[must_use]
    pub fn out_features(&self) -> usize {
        self.weights.dims().0
    }

    /// Weight sparsity as reported by the underlying ternary tensor.
    #[must_use]
    pub fn sparsity(&self) -> f32 {
        self.weights.sparsity()
    }

    /// Compression ratio of the packed weights vs. their dense original,
    /// as reported by the underlying ternary tensor.
    #[must_use]
    pub fn compression_ratio(&self) -> f32 {
        self.weights.compression_ratio()
    }

    /// Total memory footprint in bytes: packed weights plus the bias, if any.
    #[must_use]
    pub fn memory_bytes(&self) -> usize {
        let weight_bytes = self.weights.memory_bytes();
        // Size the bias by its actual element dtype instead of assuming
        // 4-byte f32 elements, so f16/bf16/f64 biases are counted correctly.
        let bias_bytes = self
            .bias
            .as_ref()
            .map_or(0, |b| b.elem_count() * b.dtype().size_in_bytes());
        weight_bytes + bias_bytes
    }

    /// Whether the weights meet the sparsity threshold in this layer's config.
    #[must_use]
    pub fn is_sparse_enough(&self) -> bool {
        self.weights.is_sparse_enough(&self.config)
    }

    /// Borrow the packed ternary weights.
    #[must_use]
    pub fn weights(&self) -> &TernaryTensor {
        &self.weights
    }

    /// Borrow the bias tensor, if one was provided.
    #[must_use]
    pub fn bias(&self) -> Option<&Tensor> {
        self.bias.as_ref()
    }

    /// Applies the layer: ternary matmul of `input` with the weights, then a
    /// broadcast add of the bias when present. For an `(batch, in_features)`
    /// input this yields a `(batch, out_features)` output.
    ///
    /// # Errors
    /// Propagates failures from [`ternary_matmul`] or the bias broadcast.
    pub fn forward(&self, input: &Tensor) -> Result<Tensor> {
        let mut output = ternary_matmul(input, &self.weights, &self.config)?;
        if let Some(ref bias) = self.bias {
            output = output.broadcast_add(bias)?;
        }
        Ok(output)
    }
}
impl Module for TernaryLinear {
fn forward(&self, input: &Tensor) -> candle_core::Result<Tensor> {
Self::forward(self, input).map_err(|e| candle_core::Error::Msg(e.to_string()))
}
}
/// Builder that quantizes dense weights into a [`TernaryLinear`] layer.
#[derive(Debug, Clone)]
pub struct TernaryLinearBuilder {
    // Quantization settings applied during `build`.
    config: TernaryConfig,
    // Whether `build` also precomputes sparsity metadata on the result
    // (only takes effect when the config enables dim metadata too).
    build_sparsity_metadata: bool,
}
impl Default for TernaryLinearBuilder {
fn default() -> Self {
Self::new()
}
}
impl TernaryLinearBuilder {
    /// Creates a builder with the default configuration and sparsity-metadata
    /// generation enabled.
    #[must_use]
    pub fn new() -> Self {
        Self {
            config: TernaryConfig::default(),
            build_sparsity_metadata: true,
        }
    }

    /// Replaces the quantization configuration.
    #[must_use]
    pub fn config(mut self, config: TernaryConfig) -> Self {
        self.config = config;
        self
    }

    /// Enables or disables precomputing sparsity metadata during `build`.
    #[must_use]
    pub fn with_sparsity_metadata(mut self, enable: bool) -> Self {
        self.build_sparsity_metadata = enable;
        self
    }

    /// Quantizes `weights` to ternary form and assembles the layer.
    ///
    /// # Errors
    /// Propagates quantization failures and bias shape mismatches.
    pub fn build(self, weights: &Tensor, bias: Option<Tensor>) -> Result<TernaryLinear> {
        use super::quantize::quantize_tensor;
        let Self {
            config,
            build_sparsity_metadata,
        } = self;
        let (mut quantized, _stats) = quantize_tensor(weights, &config)?;
        // Metadata is only built when both the builder flag and the config
        // option ask for it.
        if build_sparsity_metadata && config.enable_dim_metadata {
            quantized.build_sparsity_metadata(config.metadata_chunk_size as usize);
        }
        TernaryLinear::with_config(quantized, bias, config)
    }

    /// Quantizes an existing dense `candle_nn::Linear`, carrying over its
    /// bias (cloned) when present.
    ///
    /// # Errors
    /// Propagates quantization failures and bias shape mismatches.
    pub fn build_from_linear(self, linear: &candle_nn::Linear) -> Result<TernaryLinear> {
        self.build(linear.weight(), linear.bias().cloned())
    }
}
/// Converts a dense `candle_nn::Linear` into a [`TernaryLinear`] using the
/// default builder settings.
///
/// # Errors
/// Propagates quantization failures and bias shape mismatches.
pub fn convert_linear(linear: &candle_nn::Linear) -> Result<TernaryLinear> {
    let builder = TernaryLinearBuilder::new();
    builder.build_from_linear(linear)
}
/// Converts a dense `candle_nn::Linear` into a [`TernaryLinear`] using the
/// supplied quantization configuration.
///
/// # Errors
/// Propagates quantization failures and bias shape mismatches.
pub fn convert_linear_with_config(
    linear: &candle_nn::Linear,
    config: TernaryConfig,
) -> Result<TernaryLinear> {
    let builder = TernaryLinearBuilder::new().config(config);
    builder.build_from_linear(linear)
}
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::Device;

    // Construct a layer directly from hand-packed bit planes and check the
    // reported dimensions and the forward output shape.
    #[test]
    fn test_ternary_linear_basic() -> Result<()> {
        let shape = (4, 8);
        // One u32 word per row: bits in `plus` mark +1 weights, bits in
        // `minus` mark -1 weights.
        let plus = vec![0b00010001u32; 4];
        let minus = vec![0b00100010u32; 4];
        let scales = vec![1.0f32; 4];
        let ternary_weights = super::super::types::TernaryTensor::new(plus, minus, scales, shape);
        let layer = TernaryLinear::new(ternary_weights, None)?;
        assert_eq!(layer.in_features(), 8);
        assert_eq!(layer.out_features(), 4);
        let input = Tensor::ones((2, 8), candle_core::DType::F32, &Device::Cpu)?;
        let output = layer.forward(&input)?;
        assert_eq!(output.shape().dims(), &[2, 4]);
        Ok(())
    }

    // With a zero input the matmul contributes nothing, so the output should
    // be (approximately) the bias itself.
    #[test]
    fn test_ternary_linear_with_bias() -> Result<()> {
        use super::super::quantize::quantize_tensor;
        let weight_data = vec![0.5f32; 16];
        let weights = Tensor::from_vec(weight_data, (4, 4), &Device::Cpu)?;
        let config = TernaryConfig::default();
        let (ternary_weights, _) = quantize_tensor(&weights, &config)?;
        let bias_data = vec![1.0f32, 2.0, 3.0, 4.0];
        let bias = Tensor::from_vec(bias_data, 4, &Device::Cpu)?;
        let layer = TernaryLinear::new(ternary_weights, Some(bias))?;
        let input = Tensor::zeros((1, 4), candle_core::DType::F32, &Device::Cpu)?;
        let output = layer.forward(&input)?;
        let output_data: Vec<f32> = output.flatten_all()?.to_vec1()?;
        assert!((output_data[0] - 1.0).abs() < 0.1);
        assert!((output_data[1] - 2.0).abs() < 0.1);
        Ok(())
    }

    // Exercise the full builder chain on a weight matrix with a symmetric
    // spread of values, and check the sparse config compresses well.
    #[test]
    fn test_builder_pattern() -> Result<()> {
        let weight_data: Vec<f32> = (0..4096)
            .map(|i| {
                #[allow(clippy::cast_precision_loss)]
                {
                    (i as f32 - 2048.0) / 2048.0
                }
            })
            .collect();
        let weights = Tensor::from_vec(weight_data, (64, 64), &Device::Cpu)?;
        let layer = TernaryLinearBuilder::new()
            .config(TernaryConfig::for_sparse_model())
            .with_sparsity_metadata(true)
            .build(&weights, None)?;
        assert_eq!(layer.dims(), (64, 64));
        assert!(
            layer.compression_ratio() > 5.0,
            "Got ratio: {}",
            layer.compression_ratio()
        );
        Ok(())
    }

    // The candle `Module` adapter should forward exactly like the inherent
    // method and produce the same output shape.
    #[test]
    fn test_module_trait() -> Result<()> {
        use candle_core::Module;
        let weight_data = vec![1.0f32; 16];
        let weights = Tensor::from_vec(weight_data, (4, 4), &Device::Cpu)?;
        let layer = TernaryLinearBuilder::new().build(&weights, None)?;
        let input = Tensor::ones((2, 4), candle_core::DType::F32, &Device::Cpu)?;
        let output = Module::forward(&layer, &input)?;
        assert_eq!(output.shape().dims(), &[2, 4]);
        Ok(())
    }

    // A large uniform matrix should compress by well over 10x in packed form.
    #[test]
    fn test_memory_efficiency() -> Result<()> {
        let weight_data = vec![0.1f32; 4096 * 4096];
        let weights = Tensor::from_vec(weight_data, (4096, 4096), &Device::Cpu)?;
        let layer = TernaryLinearBuilder::new()
            .with_sparsity_metadata(false)
            .build(&weights, None)?;
        let ratio = layer.compression_ratio();
        assert!(ratio > 10.0, "Compression ratio should be >10x, got {ratio}");
        Ok(())
    }
}