use super::config::TernaryConfig;
use super::types::TernaryTensor;
use crate::error::{Result, UnslothError};
#[cfg(all(feature = "cuda", feature = "_ternary_cubecl_todo"))]
use candle_core::Device;
use candle_core::Tensor;
#[cfg(all(feature = "cuda", feature = "_ternary_cubecl_todo"))]
use cubecl::prelude::*;
#[cfg(all(feature = "cuda", feature = "_ternary_cubecl_todo"))]
use cubecl_cuda::CudaRuntime;
/// Compile-time launch configuration for the CubeCL ternary matmul kernel.
///
/// Passed to `ternary_matmul_kernel` as a `#[comptime]` argument, so these
/// values are specialized into the generated GPU code.
#[derive(Clone, Copy, Debug)]
#[cfg(all(feature = "cuda", feature = "_ternary_cubecl_todo"))]
pub struct TernaryMatmulConfig {
    /// Output-tile height in rows. Currently unread by the kernel body —
    /// presumably reserved for M-tiling; TODO confirm.
    pub tile_m: u32,
    /// Output-tile width; each kernel unit covers one column of the tile.
    pub tile_n: u32,
    /// Number of packed 32-bit words per weight row (presumably
    /// ceil(in_features / 32) — confirm against the packing code).
    pub k_words: u32,
    /// Output row (batch) count. Unread by the kernel body.
    pub m: u32,
    /// Output feature count; used to mask off out-of-range units.
    pub n: u32,
    /// Whether all-zero weight words may be skipped. Unread by the kernel
    /// body — TODO confirm intended use.
    pub skip_empty_planes: bool,
}
/// Dispatch entry point for ternary matrix multiplication.
///
/// Validates that the trailing dimension of `input` matches the weight
/// in-features, then routes to the CUDA path (only when both the `cuda` and
/// `_ternary_cubecl_todo` features are on and the weights are sparse enough)
/// or falls through to the CPU reference implementation.
///
/// # Errors
/// Returns `UnslothError::ShapeMismatch` when `input` has no dimensions or
/// its last dimension differs from the weight in-features.
#[allow(unused_variables)]
pub fn ternary_matmul(
    input: &Tensor,
    weights: &TernaryTensor,
    config: &TernaryConfig,
) -> Result<Tensor> {
    let dims = input.shape().dims();
    let (_out_features, in_features) = weights.dims();
    // A rank-0 input has no feature axis at all; report the full shape.
    let Some(&last_dim) = dims.last() else {
        return Err(UnslothError::ShapeMismatch {
            expected: vec![in_features],
            actual: dims.to_vec(),
        });
    };
    if last_dim != in_features {
        return Err(UnslothError::ShapeMismatch {
            expected: vec![in_features],
            actual: vec![last_dim],
        });
    }
    #[cfg(all(feature = "cuda", feature = "_ternary_cubecl_todo"))]
    {
        if matches!(input.device(), Device::Cuda(_)) && weights.is_sparse_enough(config) {
            return ternary_matmul_cuda(input, weights, config);
        }
    }
    ternary_matmul_cpu(input, weights)
}
/// CPU reference implementation of ternary matmul: `output = input · Wᵀ`,
/// where W is stored as ternary bit-planes with one f32 scale per output row.
///
/// Accepts any input rank >= 1; all leading dimensions are flattened into a
/// batch and the output restores them with the last dimension replaced by
/// `out_features`.
///
/// # Errors
/// Returns an error when the input's element count is incompatible with the
/// weight in-features (surfaced by the `reshape` below), or when tensor
/// extraction/construction fails.
pub fn ternary_matmul_cpu(input: &Tensor, weights: &TernaryTensor) -> Result<Tensor> {
    let input_shape = input.shape().dims();
    let (out_features, in_features) = weights.dims();
    // Flatten all leading dims into one batch axis. `saturating_sub` keeps a
    // rank-0 input from underflowing; it then fails cleanly in the reshape.
    let batch_dims = &input_shape[..input_shape.len().saturating_sub(1)];
    let batch_total: usize = batch_dims.iter().product();
    // Reshape unconditionally: candle validates the element count, so an
    // input whose last dimension differs from `in_features` errors here
    // instead of being silently mis-sliced below. (The previous 2-D fast
    // path skipped this check.)
    let input_2d = input.reshape((batch_total, in_features))?;
    let input_data: Vec<f32> = input_2d.flatten_all()?.to_vec1()?;
    let mut output_data = vec![0.0f32; batch_total * out_features];
    for b in 0..batch_total {
        let input_row = &input_data[b * in_features..(b + 1) * in_features];
        for o in 0..out_features {
            // Bit-planes encoding the ternary {-1, 0, +1} weights of row `o`.
            let planes = weights.get_row_planes(o);
            let scale = weights.scales[o];
            let mut acc = 0.0f32;
            for (i, &val) in input_row.iter().enumerate() {
                let ternary_val = planes.get(i);
                acc += val * f32::from(ternary_val);
            }
            output_data[b * out_features + o] = acc * scale;
        }
    }
    // Restore the leading batch dims, swapping the feature axis for `out_features`.
    let output_shape: Vec<usize> = batch_dims
        .iter()
        .copied()
        .chain(std::iter::once(out_features))
        .collect();
    let output = Tensor::from_vec(output_data, output_shape.as_slice(), input.device())?;
    Ok(output)
}
/// Packed CPU ternary matmul: activations are themselves quantized to
/// ternary bit-planes per batch row, so each dot product reduces to AND +
/// popcount over 32-lane words. Approximate relative to `ternary_matmul_cpu`.
///
/// `input_threshold` controls activation quantization: values in
/// `[-input_threshold, input_threshold]` are treated as zero.
///
/// # Errors
/// Returns an error when the input's element count is incompatible with the
/// weight in-features (surfaced by the `reshape` below), or when tensor
/// extraction/construction fails.
pub fn ternary_matmul_cpu_packed(
    input: &Tensor,
    weights: &TernaryTensor,
    input_threshold: f32,
) -> Result<Tensor> {
    let input_shape = input.shape().dims();
    let (out_features, in_features) = weights.dims();
    let k_words = weights.k_words;
    // Flatten leading dims into one batch axis (rank-0 safe via saturating_sub).
    let batch_dims = &input_shape[..input_shape.len().saturating_sub(1)];
    let batch_total: usize = batch_dims.iter().product();
    // Reshape unconditionally so candle validates the trailing dimension;
    // the previous 2-D fast path skipped this check and could mis-slice.
    let input_2d = input.reshape((batch_total, in_features))?;
    let input_data: Vec<f32> = input_2d.flatten_all()?.to_vec1()?;
    let mut output_data = vec![0.0f32; batch_total * out_features];
    for b in 0..batch_total {
        let input_row = &input_data[b * in_features..(b + 1) * in_features];
        let (input_plus, input_minus, input_scale) =
            quantize_activation_row(input_row, input_threshold, k_words);
        for o in 0..out_features {
            let weight_scale = weights.scales[o];
            let plane_offset = o * k_words;
            // Signed dot product via popcounts: matching signs add,
            // opposite signs subtract.
            let mut pos_matches = 0i32;
            let mut neg_matches = 0i32;
            for k in 0..k_words {
                let wp = weights.plus_plane[plane_offset + k];
                let wm = weights.minus_plane[plane_offset + k];
                let ip = input_plus[k];
                let im = input_minus[k];
                pos_matches += (wp & ip).count_ones().cast_signed();
                pos_matches += (wm & im).count_ones().cast_signed();
                neg_matches += (wp & im).count_ones().cast_signed();
                neg_matches += (wm & ip).count_ones().cast_signed();
            }
            let dot = pos_matches - neg_matches;
            #[allow(clippy::cast_precision_loss)]
            {
                output_data[b * out_features + o] = dot as f32 * weight_scale * input_scale;
            }
        }
    }
    // Restore the leading batch dims with the feature axis replaced.
    let output_shape: Vec<usize> = batch_dims
        .iter()
        .copied()
        .chain(std::iter::once(out_features))
        .collect();
    let output = Tensor::from_vec(output_data, output_shape.as_slice(), input.device())?;
    Ok(output)
}
/// Quantizes one activation row to ternary bit-planes.
///
/// Bit `i % 32` of word `i / 32` is set in the plus plane when
/// `data[i] > threshold` and in the minus plane when `data[i] < -threshold`;
/// everything in between counts as zero. Returns `(plus, minus, scale)`,
/// where `scale` is the mean absolute value of the surviving entries, or
/// `1.0` when nothing survives.
///
/// Assumes `data.len() <= k_words * 32`; a longer row would index past the
/// plane vectors.
fn quantize_activation_row(
    data: &[f32],
    threshold: f32,
    k_words: usize,
) -> (Vec<u32>, Vec<u32>, f32) {
    let mut plus_plane = vec![0u32; k_words];
    let mut minus_plane = vec![0u32; k_words];
    // Positive and negative magnitudes are accumulated separately in f64 to
    // limit rounding drift over long rows.
    let (mut pos_sum, mut neg_sum) = (0.0f64, 0.0f64);
    let mut kept = 0i32;
    for (idx, &x) in data.iter().enumerate() {
        let bit = 1u32 << (idx % 32);
        if x > threshold {
            plus_plane[idx / 32] |= bit;
            pos_sum += f64::from(x.abs());
            kept += 1;
        } else if x < -threshold {
            minus_plane[idx / 32] |= bit;
            neg_sum += f64::from(x.abs());
            kept += 1;
        }
    }
    #[allow(clippy::cast_possible_truncation)]
    let scale = if kept > 0 {
        ((pos_sum + neg_sum) / f64::from(kept)) as f32
    } else {
        // No surviving entries: a neutral scale avoids dividing by zero.
        1.0
    };
    (plus_plane, minus_plane, scale)
}
/// CUDA dispatch for ternary matmul (feature-gated; currently a stub).
///
/// Selects between a "sparse" and a "vectorized" kernel based on weight
/// sparsity and logs the decision, then unconditionally falls back to the
/// CPU path — the actual GPU launch is still under development.
#[cfg(all(feature = "cuda", feature = "_ternary_cubecl_todo"))]
fn ternary_matmul_cuda(
    input: &Tensor,
    weights: &TernaryTensor,
    config: &TernaryConfig,
) -> Result<Tensor> {
    use super::matmul_cubecl::*;
    use crate::kernels::cubecl::interop::*;
    let input_shape = input.shape().dims();
    let (out_features, in_features) = weights.dims();
    // NOTE(review): assumes dim 0 is the batch axis; higher-rank inputs are
    // not considered here — confirm against callers.
    let batch_size = input_shape[0];
    let k_words = weights.k_words as u32;
    let sparsity = weights.sparsity();
    // 90%+ zero weights would route to the sparse kernel once implemented.
    let use_sparse_kernel = sparsity >= 0.90;
    let device = input.device();
    let gpu_name = detect_gpu_name_placeholder();
    log::debug!(
        "CUDA ternary matmul: batch={}, out={}, in={}, sparsity={:.2}, kernel={}",
        batch_size,
        out_features,
        in_features,
        sparsity,
        if use_sparse_kernel {
            "sparse"
        } else {
            "vectorized"
        }
    );
    // GPU kernels are not wired up yet; use the CPU reference path.
    log::debug!("CUDA ternary matmul: falling back to CPU (GPU dispatch under development)");
    ternary_matmul_cpu(input, weights)
}
/// Hard-coded GPU-name placeholder used only for logging; real device-name
/// detection is still TODO.
#[cfg(all(feature = "cuda", feature = "_ternary_cubecl_todo"))]
fn detect_gpu_name_placeholder() -> String {
    "RTX 3090 Ti".to_string()
}
/// CubeCL kernel: one batch row per `CUBE_POS_X`, one output feature per
/// `(CUBE_POS_Y, UNIT_POS_X)` pair. Scalar reference implementation that
/// walks every bit of each packed weight word (no popcount tricks).
#[cfg(all(feature = "cuda", feature = "_ternary_cubecl_todo"))]
#[cube(launch_unchecked)]
fn ternary_matmul_kernel<F: Float>(
    input: &Array<F>,
    w_plus: &Array<u32>,
    w_minus: &Array<u32>,
    scales: &Array<F>,
    output: &mut Array<F>,
    #[comptime] config: TernaryMatmulConfig,
) {
    let batch_idx = CUBE_POS_X;
    let out_idx = CUBE_POS_Y * config.tile_n + UNIT_POS_X;
    // Mask off units past the last output feature in the ragged final tile.
    if out_idx >= config.n {
        return;
    }
    let mut acc = F::new(0.0);
    // NOTE(review): assumes each input row is laid out as k_words * 32
    // contiguous values (i.e. padded to a word multiple) — confirm against
    // the host-side buffer layout.
    let input_offset = batch_idx * config.k_words * 32;
    let weight_offset = out_idx * config.k_words;
    for k in 0..config.k_words {
        let wp = w_plus[weight_offset + k];
        let wm = w_minus[weight_offset + k];
        for bit in 0u32..32u32 {
            let dim_idx = k * 32 + bit;
            // Always true given the loop bounds (k < k_words, bit < 32);
            // kept as a defensive guard.
            if dim_idx < config.k_words * 32 {
                let mask = 1u32 << bit;
                // +1 weight sets the bit in w_plus, -1 in w_minus, 0 in neither.
                let is_pos = (wp & mask) != 0;
                let is_neg = (wm & mask) != 0;
                let input_val = input[input_offset + dim_idx];
                if is_pos {
                    acc = acc + input_val;
                } else if is_neg {
                    acc = acc - input_val;
                }
            }
        }
    }
    // Row-major output, one scale per output feature.
    let scale = scales[out_idx];
    output[batch_idx * config.n + out_idx] = acc * scale;
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernels::ternary::{quantize_tensor, TernaryConfig};
    use candle_core::Device;

    /// Quantizes a small hand-written 3x4 weight matrix and checks that the
    /// CPU matmul yields the expected output shape and (for these inputs)
    /// strictly positive values.
    #[test]
    fn test_ternary_matmul_identity() -> Result<()> {
        // Rows mix +1 / 0 / -1 entries; the manual 0.1 threshold keeps them
        // ternary as written.
        let weight_data = vec![
            1.0f32, 0.0, -1.0, 1.0, 0.0, 1.0, 1.0, 0.0, -1.0, -1.0, 0.0, 1.0,
        ];
        let weights = Tensor::from_vec(weight_data, (3, 4), &Device::Cpu)?;
        let config = TernaryConfig {
            calibration_method: super::super::config::CalibrationMethodConfig::Manual(0.1),
            ..Default::default()
        };
        let (ternary_weights, _) = quantize_tensor(&weights, &config)?;
        let input_data = vec![
            1.0f32, 2.0, 3.0, 4.0, 0.5, 1.5, 2.5, 3.5,
        ];
        let input = Tensor::from_vec(input_data, (2, 4), &Device::Cpu)?;
        let output = ternary_matmul_cpu(&input, &ternary_weights)?;
        assert_eq!(output.shape().dims(), &[2, 3]);
        let output_data: Vec<f32> = output.flatten_all()?.to_vec1()?;
        // Only signs are asserted; exact values depend on quantization scales.
        assert!(output_data[0] > 0.0);
        assert!(output_data[1] > 0.0);
        assert!(output_data[2] > 0.0);
        Ok(())
    }

    /// 3-D inputs must be flattened over leading dims and restored with the
    /// last dimension replaced by `out_features`.
    #[test]
    fn test_matmul_shape_3d() -> Result<()> {
        let weight_data = vec![0.5f32; 64 * 128];
        let weights = Tensor::from_vec(weight_data, (64, 128), &Device::Cpu)?;
        let config = TernaryConfig::default();
        let (ternary_weights, _) = quantize_tensor(&weights, &config)?;
        let input = Tensor::zeros((2, 16, 128), candle_core::DType::F32, &Device::Cpu)?;
        let output = ternary_matmul_cpu(&input, &ternary_weights)?;
        assert_eq!(output.shape().dims(), &[2, 16, 64]);
        Ok(())
    }

    /// The packed (activation-quantized) path is approximate by design, so
    /// only the output means are compared against the standard CPU path.
    #[test]
    fn test_packed_matmul_equivalence() -> Result<()> {
        let weight_data: Vec<f32> = (0..256).map(|i| (i as f32 - 128.0) / 128.0).collect();
        let weights = Tensor::from_vec(weight_data, (4, 64), &Device::Cpu)?;
        let config = TernaryConfig::default();
        let (ternary_weights, _) = quantize_tensor(&weights, &config)?;
        let input_data: Vec<f32> = (0..128)
            .map(|i| {
                #[allow(clippy::cast_precision_loss)]
                {
                    (i as f32) / 64.0 - 1.0
                }
            })
            .collect();
        let input = Tensor::from_vec(input_data, (2, 64), &Device::Cpu)?;
        let output_std = ternary_matmul_cpu(&input, &ternary_weights)?;
        let output_packed = ternary_matmul_cpu_packed(&input, &ternary_weights, 0.3)?;
        let std_data: Vec<f32> = output_std.flatten_all()?.to_vec1()?;
        let packed_data: Vec<f32> = output_packed.flatten_all()?.to_vec1()?;
        let mean_std: f32 = std_data.iter().sum::<f32>() / {
            #[allow(clippy::cast_precision_loss)]
            {
                std_data.len() as f32
            }
        };
        let mean_packed: f32 = packed_data.iter().sum::<f32>() / {
            #[allow(clippy::cast_precision_loss)]
            {
                packed_data.len() as f32
            }
        };
        // Loose tolerance: quantizing the activations loses precision.
        assert!(
            (mean_std - mean_packed).abs() < 1.0,
            "Means too different: std={}, packed={}",
            mean_std,
            mean_packed
        );
        Ok(())
    }

    /// On a CPU device the dispatcher must take the CPU path and succeed.
    #[test]
    fn test_ternary_matmul_dispatch_cpu() -> Result<()> {
        let device = Device::Cpu;
        let weight_data = vec![1.0f32, -1.0, 0.0, 1.0];
        let weights_fp = Tensor::from_vec(weight_data, (2, 2), &device)?;
        let config = TernaryConfig {
            calibration_method: super::super::config::CalibrationMethodConfig::Manual(0.1),
            ..Default::default()
        };
        let (ternary_weights, _) = quantize_tensor(&weights_fp, &config)?;
        let input_data = vec![1.0f32, 2.0, 3.0, 4.0];
        let input = Tensor::from_vec(input_data, (2, 2), &device)?;
        let output = ternary_matmul(&input, &ternary_weights, &config)?;
        assert_eq!(output.shape().dims(), &[2, 2]);
        Ok(())
    }

    /// GPU dispatch smoke test; silently passes when no CUDA device is
    /// available so CI machines without GPUs stay green.
    #[test]
    #[cfg(feature = "cuda")]
    fn test_ternary_matmul_dispatch_gpu() -> Result<()> {
        if let Ok(device) = Device::cuda_if_available(0) {
            if !matches!(device, Device::Cuda(_)) {
                // cuda_if_available fell back to CPU: nothing to test here.
                return Ok(());
            }
            let weight_data = vec![1.0f32, -1.0, 0.0, 1.0];
            let weights_fp = Tensor::from_vec(weight_data, (2, 2), &Device::Cpu)?;
            let config = TernaryConfig {
                calibration_method: super::super::config::CalibrationMethodConfig::Manual(0.1),
                ..Default::default()
            };
            let (ternary_weights, _) = quantize_tensor(&weights_fp, &config)?;
            let input_data = vec![1.0f32, 2.0, 3.0, 4.0];
            let input = Tensor::from_vec(input_data, (2, 2), &device)?;
            let output = ternary_matmul(&input, &ternary_weights, &config)?;
            assert_eq!(output.shape().dims(), &[2, 2]);
        }
        Ok(())
    }

    /// Placeholder with no assertions. NOTE(review): the function under test
    /// (`detect_gpu_name_placeholder`) is feature-gated, so this body is
    /// intentionally empty — TODO implement once real detection lands.
    #[test]
    fn test_gpu_name_detection() {
    }
}