use cubecl::prelude::*;

use candle_core::{Result, Tensor};

use crate::kernels::ternary::types::TernaryTensor;
const BITS_PER_U32: u32 = 32;
const CHUNKS_PER_U64: u32 = 64;
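/// Comptime parameters for the basic kernel: `k_words` is the number of
/// packed u32 words per row, i.e. ceil(in_features / 32); `m` is the batch
/// size and `n` the number of output features.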
#[derive(Clone, Copy, Debug)]
pub struct BasicTernaryMatmulConfig {
pub k_words: u32,
pub m: u32,
pub n: u32,
pub in_features: u32,
}
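/// Baseline ternary matmul: one thread per output element.
///
/// Each input row is quantized on the fly into two bit-planes (`A+`, `A-`),
/// and the ternary dot product reduces to AND + popcount:
///
///   dot = |W+ & A+| + |W- & A-| - |W+ & A-| - |W- & A+|
///
/// Worked 4-dim example: A = [+1, -1, +1, 0] packs to A+ = 0b0101,
/// A- = 0b0010; W = [+1, +1, -1, 0] packs to W+ = 0b0011, W- = 0b0100.
/// Then pos = 1 + 0 and neg = 1 + 1, so dot = -1, matching the dense product.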
#[cube(launch_unchecked)]
pub fn ternary_matmul_kernel_basic<F: Float>(
input: &Array<F>,
w_plus: &Array<F>,
w_minus: &Array<F>,
scales: &Array<F>,
output: &mut Array<F>,
#[comptime] config: BasicTernaryMatmulConfig,
) {
let batch_idx = CUBE_POS_X;
let out_idx = CUBE_POS_Y * CUBE_DIM_X + UNIT_POS_X;
if batch_idx >= config.m {
return;
}
if out_idx >= config.n {
return;
}
let input_offset = batch_idx * config.in_features;
let mut input_plus = Array::<u32>::new(config.k_words);
let mut input_minus = Array::<u32>::new(config.k_words);
for k in 0..config.k_words {
let mut plus_word = 0u32;
let mut minus_word = 0u32;
let base_dim = k * 32;
let end_dim = u32::min(base_dim + 32, config.in_features);
for bit in 0u32..(end_dim - base_dim) {
let dim_idx = base_dim + bit;
let val = input[input_offset + dim_idx];
let threshold = F::new(0.5);
let neg_threshold = F::new(-0.5);
if val > threshold {
plus_word = plus_word | (1u32 << bit);
} else if val < neg_threshold {
minus_word = minus_word | (1u32 << bit);
}
}
input_plus[k] = plus_word;
input_minus[k] = minus_word;
}
let weight_offset = out_idx * config.k_words;
let mut pos_sum = 0u32;
let mut neg_sum = 0u32;
for k in 0..config.k_words {
let wp_f32 = w_plus[weight_offset + k];
let wm_f32 = w_minus[weight_offset + k];
let wp_bits = u32::reinterpret(wp_f32);
let wm_bits = u32::reinterpret(wm_f32);
let ip = input_plus[k];
let im = input_minus[k];
pos_sum = pos_sum + (wp_bits & ip).count_ones();
pos_sum = pos_sum + (wm_bits & im).count_ones();
neg_sum = neg_sum + (wp_bits & im).count_ones();
neg_sum = neg_sum + (wm_bits & ip).count_ones();
}
let dot = F::cast_from(pos_sum) - F::cast_from(neg_sum);
let scale = scales[out_idx];
output[batch_idx * config.n + out_idx] = dot * scale;
}
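/// Parameters for the shared-memory tiled kernel: `tile_k` packed words are
/// staged per iteration and each thread produces `outputs_per_thread`
/// consecutive output features.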
#[derive(Clone, Copy, Debug)]
pub struct TiledTernaryMatmulConfig {
pub k_words: u32,
pub m: u32,
pub n: u32,
pub in_features: u32,
pub tile_k: u32,
pub outputs_per_thread: u32,
pub block_size: u32,
}
impl TiledTernaryMatmulConfig {
pub fn rtx_5080_preset(m: u32, n: u32, k_words: u32, in_features: u32) -> Self {
Self {
k_words,
m,
n,
in_features,
tile_k: 64,
outputs_per_thread: 2,
block_size: 256,
}
}
pub fn rtx_3090ti_preset(m: u32, n: u32, k_words: u32, in_features: u32) -> Self {
Self {
k_words,
m,
n,
in_features,
tile_k: 32,
outputs_per_thread: 2,
block_size: 256,
}
}
}
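/// Tiled ternary matmul: the cube cooperatively quantizes `tile_k` packed
/// input words into shared memory, then every thread reuses that tile for
/// its own output features, avoiding redundant per-thread quantization.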
#[cube(launch_unchecked)]
pub fn ternary_matmul_kernel_tiled<F: Float>(
input: &Array<F>,
w_plus: &Array<F>,
w_minus: &Array<F>,
scales: &Array<F>,
output: &mut Array<F>,
#[comptime] config: TiledTernaryMatmulConfig,
) {
let batch_idx = CUBE_POS_X;
let out_block_idx = CUBE_POS_Y;
let thread_idx = UNIT_POS_X;
if batch_idx >= config.m {
return;
}
let mut input_plus_tile = SharedMemory::<u32>::new(config.tile_k);
let mut input_minus_tile = SharedMemory::<u32>::new(config.tile_k);
for out_local in 0..config.outputs_per_thread {
let out_idx = out_block_idx * config.block_size * config.outputs_per_thread
+ thread_idx * config.outputs_per_thread
+ out_local;
// Do not `continue` past the barriers: every thread in the cube must reach
// the `sync_units()` calls below, so out-of-range threads stay in the loop
// and merely skip their weight reads and output store.
let is_active = out_idx < config.n;
let input_offset = batch_idx * config.in_features;
let weight_offset = out_idx * config.k_words;
let mut pos_sum = 0u32;
let mut neg_sum = 0u32;
let num_k_tiles = (config.k_words + config.tile_k - 1) / config.tile_k;
for k_tile in 0..num_k_tiles {
let k_start = k_tile * config.tile_k;
let k_end = u32::min(k_start + config.tile_k, config.k_words);
let tile_size = k_end - k_start;
if thread_idx < tile_size {
let k_word = k_start + thread_idx;
let mut plus_word = 0u32;
let mut minus_word = 0u32;
let base_dim = k_word * 32;
let end_dim = u32::min(base_dim + 32, config.in_features);
for bit in 0..(end_dim - base_dim) {
let dim_idx = base_dim + bit;
let val = input[input_offset + dim_idx];
let threshold = F::new(0.5);
let neg_threshold = F::new(-0.5);
if val > threshold {
plus_word = plus_word | (1u32 << bit);
} else if val < neg_threshold {
minus_word = minus_word | (1u32 << bit);
}
}
input_plus_tile[thread_idx] = plus_word;
input_minus_tile[thread_idx] = minus_word;
}
sync_units();
if is_active {
for k_local in 0..tile_size {
let k_word = k_start + k_local;
let wp_f32 = w_plus[weight_offset + k_word];
let wm_f32 = w_minus[weight_offset + k_word];
let wp_bits = u32::reinterpret(wp_f32);
let wm_bits = u32::reinterpret(wm_f32);
let ip = input_plus_tile[k_local];
let im = input_minus_tile[k_local];
pos_sum = pos_sum + (wp_bits & ip).count_ones();
pos_sum = pos_sum + (wm_bits & im).count_ones();
neg_sum = neg_sum + (wp_bits & im).count_ones();
neg_sum = neg_sum + (wm_bits & ip).count_ones();
}
}
sync_units();
}
if is_active {
let dot = F::cast_from(pos_sum) - F::cast_from(neg_sum);
let scale = scales[out_idx];
output[batch_idx * config.n + out_idx] = dot * scale;
}
}
}
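/// Parameters for the vectorized kernel. The tile size is expressed in
/// four-word lines: `tile_k()` returns `tile_k_lines * 4` packed u32 words.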
#[derive(Clone, Copy, Debug)]
pub struct VectorizedTernaryMatmulConfig {
pub k_words: u32,
pub m: u32,
pub n: u32,
pub in_features: u32,
pub tile_k_lines: u32,
pub outputs_per_thread: u32,
pub block_size: u32,
}
impl VectorizedTernaryMatmulConfig {
pub fn rtx_5080_preset(m: u32, n: u32, k_words: u32, in_features: u32) -> Self {
Self {
k_words,
m,
n,
in_features,
tile_k_lines: 16,
outputs_per_thread: 2,
block_size: 256,
}
}
pub fn rtx_3090ti_preset(m: u32, n: u32, k_words: u32, in_features: u32) -> Self {
Self {
k_words,
m,
n,
in_features,
tile_k_lines: 8,
outputs_per_thread: 2,
block_size: 256,
}
}
pub fn tile_k(&self) -> u32 {
self.tile_k_lines * 4
}
}
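/// Vectorized-kernel parameters plus chunk-level sparsity skipping: one bit
/// per `chunk_size`-dimension weight chunk records whether the chunk holds
/// any nonzero weight, and skipping activates only at high sparsity.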
#[derive(Clone, Copy, Debug)]
pub struct SparseOptimizedConfig {
pub base: VectorizedTernaryMatmulConfig,
pub enable_plane_skipping: bool,
pub sparsity_threshold: f32,
pub chunk_size: u32,
}
impl SparseOptimizedConfig {
pub fn from_sparsity(base: VectorizedTernaryMatmulConfig, sparsity: f32) -> Self {
Self {
base,
enable_plane_skipping: sparsity >= 0.90,
sparsity_threshold: 0.90,
chunk_size: 64,
}
}
pub fn rtx_5080_sparse(m: u32, n: u32, k_words: u32, in_features: u32, sparsity: f32) -> Self {
let base = VectorizedTernaryMatmulConfig::rtx_5080_preset(m, n, k_words, in_features);
Self::from_sparsity(base, sparsity)
}
pub fn rtx_3090ti_sparse(
m: u32,
n: u32,
k_words: u32,
in_features: u32,
sparsity: f32,
) -> Self {
let base = VectorizedTernaryMatmulConfig::rtx_3090ti_preset(m, n, k_words, in_features);
Self::from_sparsity(base, sparsity)
}
pub fn words_per_chunk(&self) -> u32 {
self.chunk_size / 32
}
}
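/// Vectorized ternary matmul: each loading thread fills four consecutive
/// tile words (one line) to improve coalescing, and the compute loop is
/// unrolled in groups of four with a scalar remainder.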
#[cube(launch_unchecked)]
pub fn ternary_matmul_kernel_vectorized<F: Float>(
input: &Array<F>,
w_plus: &Array<F>,
w_minus: &Array<F>,
scales: &Array<F>,
output: &mut Array<F>,
#[comptime] config: VectorizedTernaryMatmulConfig,
) {
let batch_idx = CUBE_POS_X;
let out_block_idx = CUBE_POS_Y;
let thread_idx = UNIT_POS_X;
if batch_idx >= config.m {
return;
}
let tile_k_words = config.tile_k();
let mut input_plus_tile = SharedMemory::<u32>::new(tile_k_words);
let mut input_minus_tile = SharedMemory::<u32>::new(tile_k_words);
for out_local in 0..config.outputs_per_thread {
let out_idx = out_block_idx * config.block_size * config.outputs_per_thread
+ thread_idx * config.outputs_per_thread
+ out_local;
// Guard with a flag instead of `continue` so all threads reach the
// `sync_units()` barriers below.
let is_active = out_idx < config.n;
let input_offset = batch_idx * config.in_features;
let weight_offset = out_idx * config.k_words;
let mut pos_sum = 0u32;
let mut neg_sum = 0u32;
let num_k_tiles = (config.k_words + tile_k_words - 1) / tile_k_words;
for k_tile in 0..num_k_tiles {
let k_start = k_tile * tile_k_words;
let k_end = u32::min(k_start + tile_k_words, config.k_words);
let tile_size = k_end - k_start;
let num_vec_loads = tile_size / 4;
let remaining_words = tile_size % 4;
if thread_idx < num_vec_loads {
let vec_idx = thread_idx;
let base_word = k_start + vec_idx * 4;
for word_offset in 0u32..4u32 {
let k_word = base_word + word_offset;
let mut plus_word = 0u32;
let mut minus_word = 0u32;
let base_dim = k_word * 32;
let end_dim = u32::min(base_dim + 32, config.in_features);
for bit in 0..(end_dim - base_dim) {
let dim_idx = base_dim + bit;
let val = input[input_offset + dim_idx];
let threshold = F::new(0.5);
let neg_threshold = F::new(-0.5);
if val > threshold {
plus_word = plus_word | (1u32 << bit);
} else if val < neg_threshold {
minus_word = minus_word | (1u32 << bit);
}
}
let tile_offset = vec_idx * 4 + word_offset;
input_plus_tile[tile_offset] = plus_word;
input_minus_tile[tile_offset] = minus_word;
}
}
if remaining_words > 0 && thread_idx < remaining_words {
let k_word = k_start + num_vec_loads * 4 + thread_idx;
let mut plus_word = 0u32;
let mut minus_word = 0u32;
let base_dim = k_word * 32;
let end_dim = u32::min(base_dim + 32, config.in_features);
for bit in 0..(end_dim - base_dim) {
let dim_idx = base_dim + bit;
let val = input[input_offset + dim_idx];
let threshold = F::new(0.5);
let neg_threshold = F::new(-0.5);
if val > threshold {
plus_word = plus_word | (1u32 << bit);
} else if val < neg_threshold {
minus_word = minus_word | (1u32 << bit);
}
}
let tile_offset = num_vec_loads * 4 + thread_idx;
input_plus_tile[tile_offset] = plus_word;
input_minus_tile[tile_offset] = minus_word;
}
sync_units();
let num_vec_iter = tile_size / 4;
let remaining = tile_size % 4;
if is_active {
for vec_idx in 0..num_vec_iter {
for word_offset in 0u32..4u32 {
let k_word = k_start + vec_idx * 4 + word_offset;
let k_local = vec_idx * 4 + word_offset;
let wp_f32 = w_plus[weight_offset + k_word];
let wm_f32 = w_minus[weight_offset + k_word];
let wp_bits = u32::reinterpret(wp_f32);
let wm_bits = u32::reinterpret(wm_f32);
let ip = input_plus_tile[k_local];
let im = input_minus_tile[k_local];
pos_sum = pos_sum + (wp_bits & ip).count_ones();
pos_sum = pos_sum + (wm_bits & im).count_ones();
neg_sum = neg_sum + (wp_bits & im).count_ones();
neg_sum = neg_sum + (wm_bits & ip).count_ones();
}
}
for word_offset in 0..remaining {
let k_word = k_start + num_vec_iter * 4 + word_offset;
let k_local = num_vec_iter * 4 + word_offset;
let wp_f32 = w_plus[weight_offset + k_word];
let wm_f32 = w_minus[weight_offset + k_word];
let wp_bits = u32::reinterpret(wp_f32);
let wm_bits = u32::reinterpret(wm_f32);
let ip = input_plus_tile[k_local];
let im = input_minus_tile[k_local];
pos_sum = pos_sum + (wp_bits & ip).count_ones();
pos_sum = pos_sum + (wm_bits & im).count_ones();
neg_sum = neg_sum + (wp_bits & im).count_ones();
neg_sum = neg_sum + (wm_bits & ip).count_ones();
}
}
sync_units();
}
if is_active {
let dot = F::cast_from(pos_sum) - F::cast_from(neg_sum);
let scale = scales[out_idx];
output[batch_idx * config.n + out_idx] = dot * scale;
}
}
}
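/// One cube per batch row on X; enough cubes on Y for
/// `block_size * outputs_per_thread` threads to cover all `n` outputs.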
pub fn get_vectorized_launch_config(
config: &VectorizedTernaryMatmulConfig,
) -> (CubeCount, CubeDim) {
let outputs_per_block = config.block_size * config.outputs_per_thread;
let grid_y = (config.n + outputs_per_block - 1) / outputs_per_block;
let cube_count = CubeCount::Static(config.m, grid_y, 1);
let cube_dim = CubeDim::new(config.block_size, 1, 1);
(cube_count, cube_dim)
}
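/// Sparse ternary matmul: like the vectorized kernel, but when
/// `enable_plane_skipping` is set, weight chunks whose sparsity-bitmap bit
/// is clear are skipped before any weight words are read.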
#[cube(launch_unchecked)]
pub fn ternary_matmul_kernel_sparse<F: Float>(
input: &Array<F>,
w_plus: &Array<F>,
w_minus: &Array<F>,
scales: &Array<F>,
sparsity_bitmap: &Array<u64>,
output: &mut Array<F>,
#[comptime] config: SparseOptimizedConfig,
) {
let batch_idx = CUBE_POS_X;
let out_block_idx = CUBE_POS_Y;
let thread_idx = UNIT_POS_X;
if batch_idx >= config.base.m {
return;
}
let tile_k_words = config.base.tile_k();
let mut input_plus_tile = SharedMemory::<u32>::new(tile_k_words);
let mut input_minus_tile = SharedMemory::<u32>::new(tile_k_words);
for out_local in 0..config.base.outputs_per_thread {
let out_idx = out_block_idx * config.base.block_size * config.base.outputs_per_thread
+ thread_idx * config.base.outputs_per_thread
+ out_local;
// Guard with a flag instead of `continue` so all threads reach the
// `sync_units()` barriers below.
let is_active = out_idx < config.base.n;
let input_offset = batch_idx * config.base.in_features;
let weight_offset = out_idx * config.base.k_words;
let mut pos_sum = 0u32;
let mut neg_sum = 0u32;
let words_per_chunk = config.words_per_chunk();
let num_chunks = (config.base.k_words + words_per_chunk - 1) / words_per_chunk;
let bitmap_words = (num_chunks + CHUNKS_PER_U64 - 1) / CHUNKS_PER_U64;
let bitmap_offset = out_idx * bitmap_words;
let num_k_tiles = (config.base.k_words + tile_k_words - 1) / tile_k_words;
for k_tile in 0..num_k_tiles {
let k_start = k_tile * tile_k_words;
let k_end = u32::min(k_start + tile_k_words, config.base.k_words);
let tile_size = k_end - k_start;
let num_vec_loads = tile_size / 4;
let remaining_words = tile_size % 4;
if thread_idx < num_vec_loads {
let vec_idx = thread_idx;
let base_word = k_start + vec_idx * 4;
for word_offset in 0u32..4u32 {
let k_word = base_word + word_offset;
let mut plus_word = 0u32;
let mut minus_word = 0u32;
let base_dim = k_word * 32;
let end_dim = u32::min(base_dim + 32, config.base.in_features);
for bit in 0..(end_dim - base_dim) {
let dim_idx = base_dim + bit;
let val = input[input_offset + dim_idx];
let threshold = F::new(0.5);
let neg_threshold = F::new(-0.5);
if val > threshold {
plus_word = plus_word | (1u32 << bit);
} else if val < neg_threshold {
minus_word = minus_word | (1u32 << bit);
}
}
let tile_offset = vec_idx * 4 + word_offset;
input_plus_tile[tile_offset] = plus_word;
input_minus_tile[tile_offset] = minus_word;
}
}
if remaining_words > 0 && thread_idx < remaining_words {
let k_word = k_start + num_vec_loads * 4 + thread_idx;
let mut plus_word = 0u32;
let mut minus_word = 0u32;
let base_dim = k_word * 32;
let end_dim = u32::min(base_dim + 32, config.base.in_features);
for bit in 0..(end_dim - base_dim) {
let dim_idx = base_dim + bit;
let val = input[input_offset + dim_idx];
let threshold = F::new(0.5);
let neg_threshold = F::new(-0.5);
if val > threshold {
plus_word = plus_word | (1u32 << bit);
} else if val < neg_threshold {
minus_word = minus_word | (1u32 << bit);
}
}
let tile_offset = num_vec_loads * 4 + thread_idx;
input_plus_tile[tile_offset] = plus_word;
input_minus_tile[tile_offset] = minus_word;
}
sync_units();
let chunk_start = k_start / words_per_chunk;
let chunk_end = (k_end + words_per_chunk - 1) / words_per_chunk;
if is_active {
for chunk_idx in chunk_start..chunk_end {
if config.enable_plane_skipping {
let bitmap_word_idx = chunk_idx / CHUNKS_PER_U64;
let bit_idx = chunk_idx % CHUNKS_PER_U64;
let bitmap = sparsity_bitmap[bitmap_offset + bitmap_word_idx];
let is_active = (bitmap & (1u64 << bit_idx)) != 0u64;
if !is_active {
continue;
}
}
let chunk_word_start = chunk_idx * words_per_chunk;
let chunk_word_end = u32::min(chunk_word_start + words_per_chunk, k_end);
for k_word in chunk_word_start..chunk_word_end {
if k_word < k_start || k_word >= k_end {
continue;
}
let k_local = k_word - k_start;
let wp_f32 = w_plus[weight_offset + k_word];
let wm_f32 = w_minus[weight_offset + k_word];
let wp_bits = u32::reinterpret(wp_f32);
let wm_bits = u32::reinterpret(wm_f32);
let ip = input_plus_tile[k_local];
let im = input_minus_tile[k_local];
pos_sum = pos_sum + (wp_bits & ip).count_ones();
pos_sum = pos_sum + (wm_bits & im).count_ones();
neg_sum = neg_sum + (wp_bits & im).count_ones();
neg_sum = neg_sum + (wm_bits & ip).count_ones();
}
}
}
sync_units();
}
if is_active {
let dot = F::cast_from(pos_sum) - F::cast_from(neg_sum);
let scale = scales[out_idx];
output[batch_idx * config.base.n + out_idx] = dot * scale;
}
}
}
pub fn get_sparse_launch_config(config: &SparseOptimizedConfig) -> (CubeCount, CubeDim) {
get_vectorized_launch_config(&config.base)
}
pub fn get_tiled_launch_config(config: &TiledTernaryMatmulConfig) -> (CubeCount, CubeDim) {
let outputs_per_block = config.block_size * config.outputs_per_thread;
let grid_y = (config.n + outputs_per_block - 1) / outputs_per_block;
let cube_count = CubeCount::Static(config.m, grid_y, 1);
let cube_dim = CubeDim::new(config.block_size, 1, 1);
(cube_count, cube_dim)
}
pub fn get_basic_launch_config(batch_size: u32, out_features: u32) -> (CubeCount, CubeDim) {
let block_size = 256u32;
let grid_y = (out_features + block_size - 1) / block_size;
let cube_count = CubeCount::Static(batch_size, grid_y, 1);
let cube_dim = CubeDim::new(block_size, 1, 1);
(cube_count, cube_dim)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_config_creation() {
let config = BasicTernaryMatmulConfig {
k_words: 4,
m: 8,
n: 64,
in_features: 128,
};
assert_eq!(config.k_words, 4);
assert_eq!(config.in_features, 128);
}
#[test]
fn test_launch_config() {
let (cube_count, cube_dim) = get_basic_launch_config(4, 512);
assert_eq!(cube_dim.x, 256);
assert_eq!(cube_dim.y, 1);
assert_eq!(cube_dim.z, 1);
if let CubeCount::Static(x, y, z) = cube_count {
assert_eq!(x, 4, "Grid X dimension should match batch size");
assert_eq!(y, 2, "Grid Y dimension should be ceil(512/256) = 2");
assert_eq!(z, 1, "Grid Z dimension should be 1");
} else {
panic!("Expected Static cube count");
}
}
#[test]
fn test_ternary_matmul_kernel_correctness() {
let batch_size = 2;
let in_features = 64;
let out_features = 3;
let k_words = 2;
let mut input_data = vec![0.0f32; batch_size * in_features];
for b in 0..batch_size {
for i in 0..in_features {
input_data[b * in_features + i] = if b == 0 {
if i % 2 == 0 {
1.0
} else {
-1.0
}
} else {
1.0
};
}
}
let mut input_plus = vec![0u32; batch_size * k_words];
let mut input_minus = vec![0u32; batch_size * k_words];
for b in 0..batch_size {
for k in 0..k_words {
let mut plus_word = 0u32;
let mut minus_word = 0u32;
for bit in 0..32 {
let dim_idx = k * 32 + bit;
let val = input_data[b * in_features + dim_idx];
if val > 0.5 {
plus_word |= 1u32 << bit;
} else if val < -0.5 {
minus_word |= 1u32 << bit;
}
}
input_plus[b * k_words + k] = plus_word;
input_minus[b * k_words + k] = minus_word;
}
}
let mut w_plus = vec![0u32; out_features * k_words];
let mut w_minus = vec![0u32; out_features * k_words];
w_plus[0] = 0xFFFFFFFF;
w_plus[1] = 0xFFFFFFFF;
w_minus[2] = 0xFFFFFFFF;
w_minus[3] = 0xFFFFFFFF;
w_plus[4] = 0xAAAAAAAA;
w_plus[5] = 0xAAAAAAAA;
w_minus[4] = 0x55555555;
w_minus[5] = 0x55555555;
let scales = vec![1.0f32; out_features];
let mut output = vec![0.0f32; batch_size * out_features];
for batch_idx in 0..batch_size {
for out_idx in 0..out_features {
let weight_offset = out_idx * k_words;
let input_offset = batch_idx * k_words;
let mut pos_sum = 0u32;
let mut neg_sum = 0u32;
for k in 0..k_words {
let wp_bits = w_plus[weight_offset + k];
let wm_bits = w_minus[weight_offset + k];
let ip = input_plus[input_offset + k];
let im = input_minus[input_offset + k];
pos_sum += (wp_bits & ip).count_ones();
pos_sum += (wm_bits & im).count_ones();
neg_sum += (wp_bits & im).count_ones();
neg_sum += (wm_bits & ip).count_ones();
}
let dot = (pos_sum as i32 - neg_sum as i32) as f32;
output[batch_idx * out_features + out_idx] = dot * scales[out_idx];
}
}
let expected_b0_f0 = 0.0;
let expected_b0_f1 = 0.0;
// b0 alternates +1/-1 while f2's weights are -1 on even dims and +1 on odd
// dims, so every one of the 64 terms contributes -1.
let expected_b0_f2 = -64.0;
let expected_b1_f0 = 64.0;
let expected_b1_f1 = -64.0;
let expected_b1_f2 = 0.0;
assert!(
(output[0] - expected_b0_f0).abs() < 0.01,
"B0F0 mismatch: expected {}, got {}",
expected_b0_f0,
output[0]
);
assert!(
(output[1] - expected_b0_f1).abs() < 0.01,
"B0F1 mismatch: expected {}, got {}",
expected_b0_f1,
output[1]
);
assert!(
(output[2] - expected_b0_f2).abs() < 0.01,
"B0F2 mismatch: expected {}, got {}",
expected_b0_f2,
output[2]
);
assert!(
(output[3] - expected_b1_f0).abs() < 0.01,
"B1F0 mismatch: expected {}, got {}",
expected_b1_f0,
output[3]
);
assert!(
(output[4] - expected_b1_f1).abs() < 0.01,
"B1F1 mismatch: expected {}, got {}",
expected_b1_f1,
output[4]
);
assert!(
(output[5] - expected_b1_f2).abs() < 0.01,
"B1F2 mismatch: expected {}, got {}",
expected_b1_f2,
output[5]
);
}
#[test]
fn test_ternary_matmul_partial_word() {
let batch_size = 2;
let in_features = 48;
let out_features = 2;
let k_words = 2;
let mut input_data = vec![0.0f32; batch_size * in_features];
for b in 0..batch_size {
for i in 0..in_features {
input_data[b * in_features + i] = if b == 0 { 1.0 } else { -1.0 };
}
}
let mut input_plus = vec![0u32; batch_size * k_words];
let mut input_minus = vec![0u32; batch_size * k_words];
for b in 0..batch_size {
for k in 0..k_words {
let mut plus_word = 0u32;
let mut minus_word = 0u32;
let base_dim = k * 32;
let end_dim = std::cmp::min(base_dim + 32, in_features);
for bit in 0..(end_dim - base_dim) {
let dim_idx = base_dim + bit;
let val = input_data[b * in_features + dim_idx];
if val > 0.5 {
plus_word |= 1u32 << bit;
} else if val < -0.5 {
minus_word |= 1u32 << bit;
}
}
input_plus[b * k_words + k] = plus_word;
input_minus[b * k_words + k] = minus_word;
}
}
let mut w_plus = vec![0u32; out_features * k_words];
let mut w_minus = vec![0u32; out_features * k_words];
w_plus[0] = 0xFFFFFFFF;
w_plus[1] = 0x0000FFFF;
w_minus[2] = 0x00000000;
w_minus[3] = 0x0000FFFF;
let scales = vec![1.0f32; out_features];
let mut output = vec![0.0f32; batch_size * out_features];
for batch_idx in 0..batch_size {
for out_idx in 0..out_features {
let weight_offset = out_idx * k_words;
let input_offset = batch_idx * k_words;
let mut pos_sum = 0u32;
let mut neg_sum = 0u32;
for k in 0..k_words {
let wp_bits = w_plus[weight_offset + k];
let wm_bits = w_minus[weight_offset + k];
let ip = input_plus[input_offset + k];
let im = input_minus[input_offset + k];
pos_sum += (wp_bits & ip).count_ones();
pos_sum += (wm_bits & im).count_ones();
neg_sum += (wp_bits & im).count_ones();
neg_sum += (wm_bits & ip).count_ones();
}
let dot = (pos_sum as i32 - neg_sum as i32) as f32;
output[batch_idx * out_features + out_idx] = dot * scales[out_idx];
}
}
let expected_b0_f0 = 48.0;
// f1 has -1 weights on dims 32..48 only; with b0 all +1 that yields -16.
let expected_b0_f1 = -16.0;
let expected_b1_f0 = -48.0;
let expected_b1_f1 = 16.0;
assert!(
(output[0] - expected_b0_f0).abs() < 0.01,
"B0F0 mismatch: expected {}, got {}",
expected_b0_f0,
output[0]
);
assert!(
(output[1] - expected_b0_f1).abs() < 0.01,
"B0F1 mismatch: expected {}, got {}",
expected_b0_f1,
output[1]
);
assert!(
(output[2] - expected_b1_f0).abs() < 0.01,
"B1F0 mismatch: expected {}, got {}",
expected_b1_f0,
output[2]
);
assert!(
(output[3] - expected_b1_f1).abs() < 0.01,
"B1F1 mismatch: expected {}, got {}",
expected_b1_f1,
output[3]
);
}
#[test]
fn test_tiled_config_creation() {
let config = TiledTernaryMatmulConfig::rtx_5080_preset(8, 512, 32, 1024);
assert_eq!(config.k_words, 32);
assert_eq!(config.m, 8);
assert_eq!(config.n, 512);
assert_eq!(config.tile_k, 64);
assert_eq!(config.outputs_per_thread, 2);
assert_eq!(config.block_size, 256);
}
#[test]
fn test_tiled_launch_config() {
let config = TiledTernaryMatmulConfig::rtx_5080_preset(4, 512, 32, 1024);
let (cube_count, cube_dim) = get_tiled_launch_config(&config);
assert_eq!(cube_dim.x, 256);
assert_eq!(cube_dim.y, 1);
assert_eq!(cube_dim.z, 1);
if let CubeCount::Static(x, y, z) = cube_count {
assert_eq!(x, 4, "Grid X should match batch size");
assert_eq!(y, 1, "Grid Y should be ceil(512 / (256*2)) = 1");
assert_eq!(z, 1);
} else {
panic!("Expected Static cube count");
}
}
#[test]
fn test_tiled_presets() {
let rtx5080 = TiledTernaryMatmulConfig::rtx_5080_preset(16, 1024, 64, 2048);
assert_eq!(rtx5080.tile_k, 64, "RTX 5080 should use larger tiles");
let rtx3090 = TiledTernaryMatmulConfig::rtx_3090ti_preset(16, 1024, 64, 2048);
assert_eq!(rtx3090.tile_k, 32, "RTX 3090 Ti should use smaller tiles");
assert_eq!(rtx5080.block_size, rtx3090.block_size);
}
#[test]
fn test_tiled_kernel_correctness() {
let batch_size = 2;
let in_features = 128;
let out_features = 4;
let k_words = 4;
let mut input_data = vec![0.0f32; batch_size * in_features];
for b in 0..batch_size {
for i in 0..in_features {
input_data[b * in_features + i] = if b == 0 {
if i % 2 == 0 {
1.0
} else {
-1.0
}
} else {
1.0
};
}
}
let mut input_plus = vec![0u32; batch_size * k_words];
let mut input_minus = vec![0u32; batch_size * k_words];
for b in 0..batch_size {
for k in 0..k_words {
let mut plus_word = 0u32;
let mut minus_word = 0u32;
for bit in 0..32 {
let dim_idx = k * 32 + bit;
let val = input_data[b * in_features + dim_idx];
if val > 0.5 {
plus_word |= 1u32 << bit;
} else if val < -0.5 {
minus_word |= 1u32 << bit;
}
}
input_plus[b * k_words + k] = plus_word;
input_minus[b * k_words + k] = minus_word;
}
}
let mut w_plus = vec![0u32; out_features * k_words];
let mut w_minus = vec![0u32; out_features * k_words];
for k in 0..k_words {
w_plus[0 * k_words + k] = 0xFFFFFFFF;
}
for k in 0..k_words {
w_minus[1 * k_words + k] = 0xFFFFFFFF;
}
for k in 0..k_words {
w_plus[2 * k_words + k] = 0xAAAAAAAA;
w_minus[2 * k_words + k] = 0x55555555;
}
w_plus[3 * k_words + 0] = 0xFFFFFFFF;
w_plus[3 * k_words + 1] = 0xFFFFFFFF;
w_minus[3 * k_words + 2] = 0xFFFFFFFF;
w_minus[3 * k_words + 3] = 0xFFFFFFFF;
let scales = vec![1.0f32; out_features];
let mut output = vec![0.0f32; batch_size * out_features];
let tile_k = 2;
for batch_idx in 0..batch_size {
for out_idx in 0..out_features {
let input_offset = batch_idx * k_words;
let weight_offset = out_idx * k_words;
let mut pos_sum = 0u32;
let mut neg_sum = 0u32;
let num_tiles = (k_words + tile_k - 1) / tile_k;
for tile_idx in 0..num_tiles {
let k_start = tile_idx * tile_k;
let k_end = std::cmp::min(k_start + tile_k, k_words);
for k in k_start..k_end {
let wp_bits = w_plus[weight_offset + k];
let wm_bits = w_minus[weight_offset + k];
let ip = input_plus[input_offset + k];
let im = input_minus[input_offset + k];
pos_sum += (wp_bits & ip).count_ones();
pos_sum += (wm_bits & im).count_ones();
neg_sum += (wp_bits & im).count_ones();
neg_sum += (wm_bits & ip).count_ones();
}
}
let dot = (pos_sum as i32 - neg_sum as i32) as f32;
output[batch_idx * out_features + out_idx] = dot * scales[out_idx];
}
}
let expected = vec![
// b0 (alternating): f0/f1 cancel, f2 contributes -1 per dim, f3 cancels.
0.0, 0.0, -128.0, 0.0,
// b1 (all +1): f0 = +128, f1 = -128, f2 and f3 cancel.
128.0, -128.0, 0.0, 0.0,
];
for (i, (actual, expected)) in output.iter().zip(expected.iter()).enumerate() {
assert!(
(actual - expected).abs() < 0.01,
"Mismatch at index {}: expected {}, got {}",
i,
expected,
actual
);
}
}
#[test]
fn test_vectorized_config_creation() {
let config = VectorizedTernaryMatmulConfig::rtx_5080_preset(8, 512, 32, 1024);
assert_eq!(config.k_words, 32);
assert_eq!(config.m, 8);
assert_eq!(config.n, 512);
assert_eq!(config.tile_k_lines, 16);
assert_eq!(config.tile_k(), 64);
assert_eq!(config.outputs_per_thread, 2);
assert_eq!(config.block_size, 256);
}
#[test]
fn test_vectorized_launch_config() {
let config = VectorizedTernaryMatmulConfig::rtx_5080_preset(4, 512, 32, 1024);
let (cube_count, cube_dim) = get_vectorized_launch_config(&config);
assert_eq!(cube_dim.x, 256);
assert_eq!(cube_dim.y, 1);
assert_eq!(cube_dim.z, 1);
if let CubeCount::Static(x, y, z) = cube_count {
assert_eq!(x, 4);
assert_eq!(y, 1);
assert_eq!(z, 1);
} else {
panic!("Expected Static cube count");
}
}
#[test]
fn test_vectorized_presets() {
let rtx5080 = VectorizedTernaryMatmulConfig::rtx_5080_preset(16, 1024, 64, 2048);
assert_eq!(rtx5080.tile_k_lines, 16, "RTX 5080 uses 16 Line<u32>");
assert_eq!(rtx5080.tile_k(), 64);
let rtx3090 = VectorizedTernaryMatmulConfig::rtx_3090ti_preset(16, 1024, 64, 2048);
assert_eq!(rtx3090.tile_k_lines, 8, "RTX 3090 Ti uses 8 Line<u32>");
assert_eq!(rtx3090.tile_k(), 32);
assert_eq!(rtx5080.block_size, rtx3090.block_size);
}
#[test]
fn test_vectorized_kernel_correctness() {
let batch_size = 2;
let in_features = 128;
let out_features = 4;
let k_words = 4;
let mut input_data = vec![0.0f32; batch_size * in_features];
for b in 0..batch_size {
for i in 0..in_features {
input_data[b * in_features + i] = if b == 0 {
if i % 2 == 0 {
1.0
} else {
-1.0
}
} else {
1.0
};
}
}
let mut input_plus = vec![0u32; batch_size * k_words];
let mut input_minus = vec![0u32; batch_size * k_words];
for b in 0..batch_size {
for k in 0..k_words {
let mut plus_word = 0u32;
let mut minus_word = 0u32;
for bit in 0..32 {
let dim_idx = k * 32 + bit;
let val = input_data[b * in_features + dim_idx];
if val > 0.5 {
plus_word |= 1u32 << bit;
} else if val < -0.5 {
minus_word |= 1u32 << bit;
}
}
input_plus[b * k_words + k] = plus_word;
input_minus[b * k_words + k] = minus_word;
}
}
let mut w_plus = vec![0u32; out_features * k_words];
let mut w_minus = vec![0u32; out_features * k_words];
for k in 0..k_words {
w_plus[0 * k_words + k] = 0xFFFFFFFF;
}
for k in 0..k_words {
w_minus[1 * k_words + k] = 0xFFFFFFFF;
}
for k in 0..k_words {
w_plus[2 * k_words + k] = 0xAAAAAAAA;
w_minus[2 * k_words + k] = 0x55555555;
}
w_plus[3 * k_words + 0] = 0xFFFFFFFF;
w_plus[3 * k_words + 1] = 0xFFFFFFFF;
w_minus[3 * k_words + 2] = 0xFFFFFFFF;
w_minus[3 * k_words + 3] = 0xFFFFFFFF;
let scales = vec![1.0f32; out_features];
let mut output = vec![0.0f32; batch_size * out_features];
for batch_idx in 0..batch_size {
for out_idx in 0..out_features {
let input_offset = batch_idx * k_words;
let weight_offset = out_idx * k_words;
let mut pos_sum = 0u32;
let mut neg_sum = 0u32;
for k in 0..k_words {
let wp_bits = w_plus[weight_offset + k];
let wm_bits = w_minus[weight_offset + k];
let ip = input_plus[input_offset + k];
let im = input_minus[input_offset + k];
pos_sum += (wp_bits & ip).count_ones();
pos_sum += (wm_bits & im).count_ones();
neg_sum += (wp_bits & im).count_ones();
neg_sum += (wm_bits & ip).count_ones();
}
let dot = (pos_sum as i32 - neg_sum as i32) as f32;
output[batch_idx * out_features + out_idx] = dot * scales[out_idx];
}
}
let expected = vec![
// b0 (alternating): f0/f1 cancel, f2 contributes -1 per dim, f3 cancels.
0.0, 0.0, -128.0, 0.0,
// b1 (all +1): f0 = +128, f1 = -128, f2 and f3 cancel.
128.0, -128.0, 0.0, 0.0,
];
for (i, (actual, expected)) in output.iter().zip(expected.iter()).enumerate() {
assert!(
(actual - expected).abs() < 0.01,
"Vectorized mismatch at index {}: expected {}, got {}",
i,
expected,
actual
);
}
}
#[test]
fn test_vectorized_kernel_partial() {
let batch_size = 2;
let in_features = 96;
let out_features = 2;
let k_words = 3;
let mut input_data = vec![0.0f32; batch_size * in_features];
for b in 0..batch_size {
for i in 0..in_features {
input_data[b * in_features + i] = if b == 0 { 1.0 } else { -1.0 };
}
}
let mut input_plus = vec![0u32; batch_size * k_words];
let mut input_minus = vec![0u32; batch_size * k_words];
for b in 0..batch_size {
for k in 0..k_words {
let mut plus_word = 0u32;
let mut minus_word = 0u32;
for bit in 0..32 {
let dim_idx = k * 32 + bit;
let val = input_data[b * in_features + dim_idx];
if val > 0.5 {
plus_word |= 1u32 << bit;
} else if val < -0.5 {
minus_word |= 1u32 << bit;
}
}
input_plus[b * k_words + k] = plus_word;
input_minus[b * k_words + k] = minus_word;
}
}
let mut w_plus = vec![0u32; out_features * k_words];
let mut w_minus = vec![0u32; out_features * k_words];
for k in 0..k_words {
w_plus[0 * k_words + k] = 0xFFFFFFFF;
}
for k in 0..k_words {
w_minus[1 * k_words + k] = 0xFFFFFFFF;
}
let scales = vec![1.0f32; out_features];
let mut output = vec![0.0f32; batch_size * out_features];
for batch_idx in 0..batch_size {
for out_idx in 0..out_features {
let input_offset = batch_idx * k_words;
let weight_offset = out_idx * k_words;
let mut pos_sum = 0u32;
let mut neg_sum = 0u32;
for k in 0..k_words {
let wp_bits = w_plus[weight_offset + k];
let wm_bits = w_minus[weight_offset + k];
let ip = input_plus[input_offset + k];
let im = input_minus[input_offset + k];
pos_sum += (wp_bits & ip).count_ones();
pos_sum += (wm_bits & im).count_ones();
neg_sum += (wp_bits & im).count_ones();
neg_sum += (wm_bits & ip).count_ones();
}
let dot = (pos_sum as i32 - neg_sum as i32) as f32;
output[batch_idx * out_features + out_idx] = dot * scales[out_idx];
}
}
let expected = vec![96.0, -96.0, -96.0, 96.0];
for (i, (actual, expected)) in output.iter().zip(expected.iter()).enumerate() {
assert!(
(actual - expected).abs() < 0.01,
"Partial vectorized mismatch at {}: expected {}, got {}",
i,
expected,
actual
);
}
}
#[test]
fn test_sparse_config_creation() {
let base = VectorizedTernaryMatmulConfig::rtx_5080_preset(8, 512, 32, 1024);
let sparse = SparseOptimizedConfig::from_sparsity(base, 0.95);
assert!(sparse.enable_plane_skipping);
assert_eq!(sparse.chunk_size, 64);
assert_eq!(sparse.words_per_chunk(), 2);
let dense = SparseOptimizedConfig::from_sparsity(base, 0.50);
assert!(!dense.enable_plane_skipping);
}
#[test]
fn test_sparse_presets() {
let rtx5080 = SparseOptimizedConfig::rtx_5080_sparse(16, 1024, 64, 2048, 0.95);
assert!(rtx5080.enable_plane_skipping);
assert_eq!(rtx5080.base.tile_k(), 64);
let rtx3090 = SparseOptimizedConfig::rtx_3090ti_sparse(16, 1024, 64, 2048, 0.95);
assert!(rtx3090.enable_plane_skipping);
assert_eq!(rtx3090.base.tile_k(), 32);
}
#[test]
fn test_sparse_launch_config() {
let base = VectorizedTernaryMatmulConfig::rtx_5080_preset(4, 512, 32, 1024);
let sparse = SparseOptimizedConfig::from_sparsity(base, 0.95);
let (cube_count, cube_dim) = get_sparse_launch_config(&sparse);
assert_eq!(cube_dim.x, 256);
if let CubeCount::Static(x, y, z) = cube_count {
assert_eq!(x, 4);
assert_eq!(y, 1);
assert_eq!(z, 1);
} else {
panic!("Expected Static cube count");
}
}
#[test]
fn test_sparse_kernel_correctness() {
let batch_size = 2;
let in_features = 128;
let out_features = 4;
let k_words = 4;
let mut input_data = vec![0.0f32; batch_size * in_features];
for b in 0..batch_size {
for i in 0..in_features {
input_data[b * in_features + i] = if b == 0 {
if i % 2 == 0 {
1.0
} else {
-1.0
}
} else {
1.0
};
}
}
let mut input_plus = vec![0u32; batch_size * k_words];
let mut input_minus = vec![0u32; batch_size * k_words];
for b in 0..batch_size {
for k in 0..k_words {
let mut plus_word = 0u32;
let mut minus_word = 0u32;
for bit in 0..32 {
let dim_idx = k * 32 + bit;
let val = input_data[b * in_features + dim_idx];
if val > 0.5 {
plus_word |= 1u32 << bit;
} else if val < -0.5 {
minus_word |= 1u32 << bit;
}
}
input_plus[b * k_words + k] = plus_word;
input_minus[b * k_words + k] = minus_word;
}
}
let mut w_plus = vec![0u32; out_features * k_words];
let w_minus = vec![0u32; out_features * k_words];
for out_idx in 0..out_features {
w_plus[out_idx * k_words] = 0xFFFFFFFF;
w_plus[out_idx * k_words + 1] = 0xFFFFFFFF;
}
let num_chunks = 2;
let bitmap = vec![0x1u64; out_features];
let scales = vec![1.0f32; out_features];
let mut output = vec![0.0f32; batch_size * out_features];
for batch_idx in 0..batch_size {
for out_idx in 0..out_features {
let input_offset = batch_idx * k_words;
let weight_offset = out_idx * k_words;
let mut pos_sum = 0u32;
let mut neg_sum = 0u32;
for chunk_idx in 0..num_chunks {
let is_active = (bitmap[out_idx] & (1u64 << chunk_idx)) != 0;
if !is_active {
continue;
}
let chunk_word_start = chunk_idx * 2;
let chunk_word_end = std::cmp::min(chunk_word_start + 2, k_words);
for k_word in chunk_word_start..chunk_word_end {
let wp_bits = w_plus[weight_offset + k_word];
let wm_bits = w_minus[weight_offset + k_word];
let ip = input_plus[input_offset + k_word];
let im = input_minus[input_offset + k_word];
pos_sum += (wp_bits & ip).count_ones();
pos_sum += (wm_bits & im).count_ones();
neg_sum += (wp_bits & im).count_ones();
neg_sum += (wm_bits & ip).count_ones();
}
}
let dot = (pos_sum as i32 - neg_sum as i32) as f32;
output[batch_idx * out_features + out_idx] = dot * scales[out_idx];
}
}
for (i, &val) in output.iter().enumerate() {
let batch = i / out_features;
if batch == 1 {
assert!(
(val - 64.0).abs() < 0.01,
"Sparse kernel mismatch at {}: expected 64.0, got {}",
i,
val
);
}
}
}
#[test]
fn test_sparse_kernel_skip_detection() {
let k_words = 8;
let out_features = 2;
let mut w_plus = vec![0u32; out_features * k_words];
w_plus[0 * k_words + 0] = 0xFF;
w_plus[1 * k_words + 2] = 0xFF;
w_plus[1 * k_words + 4] = 0xFF;
let _num_chunks = 4; // 8 words at 2 words per chunk
let mut bitmap = vec![0u64; out_features];
bitmap[0] = 0x1;
bitmap[1] = 0x6;
assert_eq!(
(bitmap[0] & 0x1),
0x1,
"Chunk 0 should be active for feature 0"
);
assert_eq!(
(bitmap[0] & 0x2),
0x0,
"Chunk 1 should be inactive for feature 0"
);
assert_eq!(
(bitmap[1] & 0x1),
0x0,
"Chunk 0 should be inactive for feature 1"
);
assert_eq!(
(bitmap[1] & 0x2),
0x2,
"Chunk 1 should be active for feature 1"
);
assert_eq!(
(bitmap[1] & 0x4),
0x4,
"Chunk 2 should be active for feature 1"
);
}
}
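/// Coarse classification of how zeros are spread across 64-dimension chunks,
/// inferred from the variance of per-chunk sparsity.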
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum SparsityPattern {
Uniform,
Clustered,
Layered,
}
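/// Per-chunk sparsity summary: the mean, its variance across chunks, and the
/// pattern class inferred from that variance.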
#[derive(Debug, Clone)]
pub struct SparsityDistribution {
pub overall_sparsity: f32,
pub variance: f32,
pub pattern: SparsityPattern,
}
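/// Tuning knobs for adaptive plane skipping: the sparsity threshold that
/// activates skipping, the range the threshold may adapt within, and bounds
/// on the chunk size.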
#[derive(Clone, Debug)]
pub struct DynamicPlaneSkippingConfig {
pub enable: bool,
pub initial_threshold: f32,
pub adaptive_range: (f32, f32),
pub min_chunk_size: usize,
pub max_chunk_size: usize,
pub profiling_overhead_limit: f32,
}
impl Default for DynamicPlaneSkippingConfig {
fn default() -> Self {
Self {
enable: false,
initial_threshold: 0.90,
adaptive_range: (0.85, 0.95),
min_chunk_size: 32,
max_chunk_size: 128,
profiling_overhead_limit: 0.005,
}
}
}
impl DynamicPlaneSkippingConfig {
pub fn with_enable(mut self, enable: bool) -> Self {
self.enable = enable;
self
}
/// Currently an alias for [`Self::with_enable`]: a separate pattern-detection
/// flag does not exist yet, so this toggles the same switch.
pub fn with_pattern_detection(mut self, enable: bool) -> Self {
self.enable = enable;
self
}
}
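/// Walks both bit-planes in 64-dimension chunks, measures the fraction of
/// coordinates where neither plane has its bit set, and classifies the
/// resulting per-chunk sparsity profile.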
pub fn analyze_tensor_sparsity_distribution(tensor: &TernaryTensor) -> SparsityDistribution {
let out_features = tensor.shape()[0];
let in_features = tensor.shape()[1];
let chunk_size = 64;
let num_chunks = (in_features + chunk_size - 1) / chunk_size;
let mut chunk_sparsities = Vec::new();
for chunk_idx in 0..num_chunks {
let start_dim = chunk_idx * chunk_size;
let end_dim = (start_dim + chunk_size).min(in_features);
let mut zero_count = 0;
let mut total_count = 0;
for row in 0..out_features {
for dim in start_dim..end_dim {
total_count += 1;
let word_idx = dim / 32;
let bit_idx = dim % 32;
let plus_data = tensor.plus_plane().storage();
let minus_data = tensor.minus_plane().storage();
let plus_idx = row * ((in_features + 31) / 32) + word_idx;
let minus_idx = row * ((in_features + 31) / 32) + word_idx;
if plus_idx < plus_data.len() && minus_idx < minus_data.len() {
let plus_word = plus_data[plus_idx];
let minus_word = minus_data[minus_idx];
let plus_bit = (plus_word >> bit_idx) & 1;
let minus_bit = (minus_word >> bit_idx) & 1;
if plus_bit == 0 && minus_bit == 0 {
zero_count += 1;
}
}
}
}
if total_count > 0 {
chunk_sparsities.push(zero_count as f32 / total_count as f32);
}
}
let overall_sparsity = if !chunk_sparsities.is_empty() {
chunk_sparsities.iter().sum::<f32>() / chunk_sparsities.len() as f32
} else {
0.0
};
let variance = if chunk_sparsities.len() > 1 {
let mean = overall_sparsity;
let sum_sq_diff: f32 = chunk_sparsities.iter().map(|&x| (x - mean).powi(2)).sum();
sum_sq_diff / chunk_sparsities.len() as f32
} else {
0.0
};
let pattern = if variance < 0.1 {
SparsityPattern::Uniform
} else if variance > 0.3 {
SparsityPattern::Clustered
} else {
SparsityPattern::Layered
};
SparsityDistribution {
overall_sparsity,
variance,
pattern,
}
}
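/// Adaptive entry point: profiles the weight tensor's sparsity distribution
/// and derives chunk-size and threshold choices from the detected pattern
/// before dispatching the matmul.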
#[allow(dead_code)]
pub fn adaptive_plane_skipping_matmul(
input: &Tensor,
weights: &TernaryTensor,
config: DynamicPlaneSkippingConfig,
) -> Result<Tensor> {
use crate::kernels::ternary::matmul::ternary_matmul;
if !config.enable {
return ternary_matmul(input, weights);
}
let distribution = analyze_tensor_sparsity_distribution(weights);
// These adapted values feed the future pattern-specialized kernels; until
// those are wired up, the call below still dispatches the standard matmul,
// so the bindings are underscore-prefixed to document intent without
// triggering unused-variable warnings.
let _adapted_chunk_size = match distribution.pattern {
SparsityPattern::Uniform => 64,
SparsityPattern::Clustered => config.min_chunk_size,
SparsityPattern::Layered => config.max_chunk_size,
};
let _adapted_threshold = match distribution.pattern {
SparsityPattern::Uniform => config.initial_threshold,
SparsityPattern::Clustered => config.adaptive_range.0,
SparsityPattern::Layered => config.adaptive_range.1,
};
ternary_matmul(input, weights)
}
#[cfg(test)]
mod dynamic_plane_skipping_tests {
use super::*;
use crate::kernels::ternary::quantize::quantize_tensor_to_ternary;
use candle_core::{DType, Device, Tensor};
#[test]
fn test_dynamic_config_creation() {
let config = DynamicPlaneSkippingConfig::default();
assert!(!config.enable, "Should be disabled by default");
assert_eq!(config.initial_threshold, 0.90);
assert_eq!(config.adaptive_range, (0.85, 0.95));
assert_eq!(config.min_chunk_size, 32);
assert_eq!(config.max_chunk_size, 128);
let enabled_config = config.with_enable(true);
assert!(enabled_config.enable);
}
#[test]
fn test_analyze_uniform_sparsity() {
let device = Device::Cpu;
let mut data_vec = vec![0.0f32; 4 * 128];
for i in (0..data_vec.len()).step_by(10) {
data_vec[i] = 1.0;
}
let data = Tensor::from_vec(data_vec, (4, 128), &device).unwrap();
let ternary = quantize_tensor_to_ternary(&data, 0.5).unwrap();
let distribution = analyze_tensor_sparsity_distribution(&ternary);
assert!(distribution.overall_sparsity > 0.85);
assert!(
distribution.variance < 0.15,
"Uniform should have low variance"
);
assert_eq!(distribution.pattern, SparsityPattern::Uniform);
}
#[test]
fn test_analyze_clustered_sparsity() {
let device = Device::Cpu;
let mut data_vec = vec![0.0f32; 4 * 128];
for i in 0..32 {
data_vec[i] = 1.0;
}
data_vec[64] = 1.0;
data_vec[96] = 1.0;
let data = Tensor::from_vec(data_vec, (4, 128), &device).unwrap();
let ternary = quantize_tensor_to_ternary(&data, 0.5).unwrap();
let distribution = analyze_tensor_sparsity_distribution(&ternary);
assert!(
distribution.variance > 0.2,
"Clustered should have high variance"
);
}
#[test]
fn test_adaptive_matmul_shape() {
let device = Device::Cpu;
let batch = 2;
let in_features = 128;
let out_features = 64;
let input = Tensor::ones((batch, in_features), DType::F32, &device).unwrap();
let weights_data = Tensor::ones((out_features, in_features), DType::F32, &device).unwrap();
let weights = quantize_tensor_to_ternary(&weights_data, 0.5).unwrap();
let config = DynamicPlaneSkippingConfig::default().with_enable(true);
let output = adaptive_plane_skipping_matmul(&input, &weights, config).unwrap();
assert_eq!(output.dims(), &[batch, out_features]);
}
}
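/// Candidate storage layouts for ternary weights, chosen by
/// `format_selection_heuristic` from sparsity and row balance.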
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum SparseStorageFormat {
Bitsliced,
CSR,
ELL,
Compressed,
}
#[derive(Clone, Debug)]
pub struct ChunkOptimizationConfig {
pub enable_format_selection: bool,
pub on_gpu: bool,
pub force_format: Option<SparseStorageFormat>,
}
impl Default for ChunkOptimizationConfig {
fn default() -> Self {
Self {
enable_format_selection: true,
on_gpu: false,
force_format: None,
}
}
}
impl ChunkOptimizationConfig {
pub fn for_gpu() -> Self {
Self {
enable_format_selection: true,
on_gpu: true,
force_format: None,
}
}
pub fn with_format(mut self, format: SparseStorageFormat) -> Self {
self.force_format = Some(format);
self
}
}
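/// Row-level nonzero statistics for format selection; `row_variance` is the
/// coefficient of variation of nonzeros per row.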
#[derive(Clone, Debug)]
pub struct SparsityStats {
pub sparsity: f32,
pub row_variance: f32,
pub max_nnz_per_row: usize,
pub avg_nnz_per_row: f32,
}
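/// Counts nonzeros per row by OR-ing the plus and minus planes and
/// popcounting each packed word.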
pub fn analyze_sparsity_stats(tensor: &TernaryTensor) -> SparsityStats {
let out_features = tensor.plus_plane.dims()[0];
let in_features = tensor.plus_plane.dims()[1] * BITS_PER_U32 as usize;
let plus_data = tensor.plus_plane.to_vec2::<u32>().unwrap();
let minus_data = tensor.minus_plane.to_vec2::<u32>().unwrap();
let mut total_nonzero = 0;
let mut nnz_per_row = Vec::new();
for row_idx in 0..out_features {
let mut row_nnz = 0;
for word_idx in 0..plus_data[row_idx].len() {
let plus_word = plus_data[row_idx][word_idx];
let minus_word = minus_data[row_idx][word_idx];
row_nnz += (plus_word | minus_word).count_ones() as usize;
}
total_nonzero += row_nnz;
nnz_per_row.push(row_nnz);
}
let total_elements = out_features * in_features;
let sparsity = 1.0 - (total_nonzero as f32 / total_elements as f32);
let avg_nnz = nnz_per_row.iter().sum::<usize>() as f32 / nnz_per_row.len() as f32;
let variance = nnz_per_row
.iter()
.map(|&x| (x as f32 - avg_nnz).powi(2))
.sum::<f32>()
/ nnz_per_row.len() as f32;
let row_variance = variance.sqrt() / avg_nnz.max(1.0);
let max_nnz_per_row = *nnz_per_row.iter().max().unwrap_or(&0);
SparsityStats {
sparsity,
row_variance,
max_nnz_per_row,
avg_nnz_per_row: avg_nnz,
}
}
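/// Dense tensors (<50% zeros) stay bit-sliced; above 98% sparsity run-length
/// compression wins; above 95% with balanced rows, CSR; on GPU with very
/// balanced rows, ELL; everything else stays bit-sliced.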
pub fn format_selection_heuristic(
sparsity: f32,
row_variance: f32,
on_gpu: bool,
) -> SparseStorageFormat {
if sparsity < 0.50 {
SparseStorageFormat::Bitsliced
} else if sparsity > 0.98 {
SparseStorageFormat::Compressed
} else if sparsity > 0.95 && row_variance < 0.2 {
SparseStorageFormat::CSR
} else if on_gpu && row_variance < 0.1 {
SparseStorageFormat::ELL
} else {
SparseStorageFormat::Bitsliced
}
}
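/// Classic compressed-sparse-row layout: `row_ptr[r]..row_ptr[r + 1]` spans
/// the nonzero columns and ternary values of row `r`.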
#[derive(Clone, Debug)]
pub struct CSRFormat {
pub row_ptr: Vec<usize>,
pub col_indices: Vec<usize>,
pub values: Vec<i8>,
}
pub fn convert_to_csr_format(tensor: &TernaryTensor) -> CSRFormat {
let out_features = tensor.plus_plane.dims()[0];
let in_features = tensor.plus_plane.dims()[1] * BITS_PER_U32 as usize;
let plus_data = tensor.plus_plane.to_vec2::<u32>().unwrap();
let minus_data = tensor.minus_plane.to_vec2::<u32>().unwrap();
let mut row_ptr = vec![0];
let mut col_indices = Vec::new();
let mut values = Vec::new();
for row_idx in 0..out_features {
for word_idx in 0..plus_data[row_idx].len() {
let plus_word = plus_data[row_idx][word_idx];
let minus_word = minus_data[row_idx][word_idx];
for bit in 0..BITS_PER_U32 {
let col = word_idx * BITS_PER_U32 as usize + bit as usize;
if col >= in_features {
break;
}
let is_plus = (plus_word & (1u32 << bit)) != 0;
let is_minus = (minus_word & (1u32 << bit)) != 0;
if is_plus || is_minus {
col_indices.push(col);
values.push(if is_plus { 1 } else { -1 });
}
}
}
row_ptr.push(values.len());
}
CSRFormat {
row_ptr,
col_indices,
values,
}
}
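/// ELLPACK layout: every row is padded to `max_nnz_per_row` entries so the
/// column and value arrays stay rectangular, a GPU-friendly stride pattern.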
#[derive(Clone, Debug)]
pub struct ELLFormat {
pub max_nnz_per_row: usize,
pub col_indices: Vec<Vec<usize>>,
pub values: Vec<Vec<i8>>,
}
pub fn convert_to_ell_format(tensor: &TernaryTensor) -> ELLFormat {
let stats = analyze_sparsity_stats(tensor);
let out_features = tensor.plus_plane.dims()[0];
let in_features = tensor.plus_plane.dims()[1] * BITS_PER_U32 as usize;
let plus_data = tensor.plus_plane.to_vec2::<u32>().unwrap();
let minus_data = tensor.minus_plane.to_vec2::<u32>().unwrap();
let max_nnz = stats.max_nnz_per_row;
let mut col_indices = vec![vec![0; max_nnz]; out_features];
let mut values = vec![vec![0i8; max_nnz]; out_features];
for row_idx in 0..out_features {
let mut nnz_count = 0;
for word_idx in 0..plus_data[row_idx].len() {
let plus_word = plus_data[row_idx][word_idx];
let minus_word = minus_data[row_idx][word_idx];
for bit in 0..BITS_PER_U32 {
let col = word_idx * BITS_PER_U32 as usize + bit as usize;
if col >= in_features {
break;
}
let is_plus = (plus_word & (1u32 << bit)) != 0;
let is_minus = (minus_word & (1u32 << bit)) != 0;
if is_plus || is_minus {
if nnz_count < max_nnz {
col_indices[row_idx][nnz_count] = col;
values[row_idx][nnz_count] = if is_plus { 1 } else { -1 };
nnz_count += 1;
}
}
}
}
}
ELLFormat {
max_nnz_per_row: max_nnz,
col_indices,
values,
}
}
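/// Run-length layout: each row stores `(start_col, value)` pairs marking
/// where a run of identical nonzero values begins; zeros end runs implicitly.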
#[derive(Clone, Debug)]
pub struct CompressedFormat {
pub runs: Vec<Vec<(usize, i8)>>,
}
pub fn convert_to_compressed_format(tensor: &TernaryTensor) -> CompressedFormat {
let out_features = tensor.plus_plane.dims()[0];
let in_features = tensor.plus_plane.dims()[1] * BITS_PER_U32 as usize;
let plus_data = tensor.plus_plane.to_vec2::<u32>().unwrap();
let minus_data = tensor.minus_plane.to_vec2::<u32>().unwrap();
let mut runs = Vec::new();
for row_idx in 0..out_features {
let mut row_runs = Vec::new();
let mut current_value: Option<i8> = None;
let mut run_start = 0;
for word_idx in 0..plus_data[row_idx].len() {
let plus_word = plus_data[row_idx][word_idx];
let minus_word = minus_data[row_idx][word_idx];
for bit in 0..BITS_PER_U32 {
let col = word_idx * BITS_PER_U32 as usize + bit as usize;
if col >= in_features {
break;
}
let is_plus = (plus_word & (1u32 << bit)) != 0;
let is_minus = (minus_word & (1u32 << bit)) != 0;
let value = if is_plus {
Some(1)
} else if is_minus {
Some(-1)
} else {
None
};
if value != current_value {
if let Some(v) = current_value {
row_runs.push((run_start, v));
}
current_value = value;
run_start = col;
}
}
}
if let Some(v) = current_value {
row_runs.push((run_start, v));
}
runs.push(row_runs);
}
CompressedFormat { runs }
}
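/// Picks a storage format from measured sparsity statistics (or honors a
/// forced format) and runs the matmul; format-specific kernels are not yet
/// implemented, so every path executes the bit-sliced CPU matmul.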
pub fn chunk_optimized_ternary_matmul(
input: &Tensor,
weights: &TernaryTensor,
config: ChunkOptimizationConfig,
) -> Result<Tensor> {
let stats = analyze_sparsity_stats(weights);
let format = if let Some(forced) = config.force_format {
forced
} else if config.enable_format_selection {
format_selection_heuristic(stats.sparsity, stats.row_variance, config.on_gpu)
} else {
SparseStorageFormat::Bitsliced
};
// Format-specific execution paths are not implemented yet: every selected
// format currently falls back to the bit-sliced CPU reference matmul.
match format {
SparseStorageFormat::Bitsliced
| SparseStorageFormat::CSR
| SparseStorageFormat::ELL
| SparseStorageFormat::Compressed => ternary_matmul_cpu(input, weights),
}
}
#[cfg(test)]
mod tests_chunk_optimization {
use super::*;
#[test]
fn test_format_selection_dense() {
let format = format_selection_heuristic(0.30, 0.1, false);
assert_eq!(format, SparseStorageFormat::Bitsliced);
}
#[test]
fn test_format_selection_very_sparse() {
let format = format_selection_heuristic(0.96, 0.15, false);
assert_eq!(format, SparseStorageFormat::CSR);
}
#[cfg(feature = "cuda")]
#[test]
fn test_convert_to_csr() {
use crate::kernels::quantization::quantize_tensor_to_ternary;
let device = candle_core::Device::Cpu;
let data = Tensor::from_slice(
&[1.0f32, 0.0, -1.0, 0.0, 1.0, 0.0, 0.0, -1.0],
&[2, 4],
&device,
)
.unwrap();
let ternary = quantize_tensor_to_ternary(&data, 0.5).unwrap();
let csr = convert_to_csr_format(&ternary);
assert_eq!(csr.row_ptr[0], 0);
assert_eq!(csr.row_ptr[1], 2);
assert_eq!(csr.row_ptr[2], 4);
assert_eq!(csr.values.len(), 4);
}
#[cfg(feature = "cuda")]
#[test]
fn test_chunk_optimized_matmul_shape() {
use crate::kernels::quantization::quantize_tensor_to_ternary;
use candle_core::DType;
let device = candle_core::Device::Cpu;
let batch = 2;
let in_features = 64;
let out_features = 32;
let input = Tensor::ones((batch, in_features), DType::F32, &device).unwrap();
let weights_data = Tensor::ones((out_features, in_features), DType::F32, &device).unwrap();
let weights = quantize_tensor_to_ternary(&weights_data, 0.5).unwrap();
let config = ChunkOptimizationConfig::default();
let output = chunk_optimized_ternary_matmul(&input, &weights, config).unwrap();
assert_eq!(output.dims(), &[batch, out_features]);
}
}
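/// Full profiling report for a ternary weight tensor: measured sparsity plus
/// recommended format, chunk size, and rough memory/speedup estimates.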
#[derive(Clone, Debug)]
pub struct SparsityProfile {
pub overall_sparsity: f32,
pub per_row_sparsity: Vec<f32>,
pub sparsity_pattern: SparsityPattern,
pub recommended_format: SparseStorageFormat,
pub recommended_chunk_size: usize,
pub estimated_memory_savings: f32,
pub estimated_speedup: f32,
}
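/// Builds a [`SparsityProfile`] by combining row statistics with the chunk
/// distribution, then attaching format, chunk-size, and savings estimates.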
pub fn profile_tensor_sparsity(tensor: &TernaryTensor, on_gpu: bool) -> SparsityProfile {
let stats = analyze_sparsity_stats(tensor);
let distribution = analyze_tensor_sparsity_distribution(tensor);
let pattern = if distribution.variance < 0.1 {
SparsityPattern::Uniform
} else if distribution.variance > 0.3 {
SparsityPattern::Clustered
} else {
SparsityPattern::Layered
};
let recommended_format = format_selection_heuristic(stats.sparsity, stats.row_variance, on_gpu);
let recommended_chunk_size = match pattern {
SparsityPattern::Uniform => 64,
SparsityPattern::Clustered => 32,
SparsityPattern::Layered => 128,
};
let estimated_memory_savings = match recommended_format {
SparseStorageFormat::Bitsliced => 1.0,
SparseStorageFormat::CSR => {
if stats.sparsity > 0.95 {
3.0
} else {
1.5
}
}
SparseStorageFormat::ELL => 1.2,
SparseStorageFormat::Compressed => {
if stats.sparsity > 0.98 {
5.0
} else {
2.0
}
}
};
let estimated_speedup = if stats.sparsity > 0.95 {
2.0 + (stats.sparsity - 0.95) * 10.0
} else {
1.0 + stats.sparsity * 0.5
};
let out_features = tensor.plus_plane.dims()[0];
let in_features = tensor.plus_plane.dims()[1] * BITS_PER_U32 as usize;
let plus_data = tensor.plus_plane.to_vec2::<u32>().unwrap();
let minus_data = tensor.minus_plane.to_vec2::<u32>().unwrap();
let mut per_row_sparsity = Vec::new();
for row_idx in 0..out_features {
let mut row_nnz = 0;
for word_idx in 0..plus_data[row_idx].len() {
let plus_word = plus_data[row_idx][word_idx];
let minus_word = minus_data[row_idx][word_idx];
row_nnz += (plus_word | minus_word).count_ones() as usize;
}
let row_sparsity = 1.0 - (row_nnz as f32 / in_features as f32);
per_row_sparsity.push(row_sparsity);
}
SparsityProfile {
overall_sparsity: stats.sparsity,
per_row_sparsity,
sparsity_pattern: pattern,
recommended_format,
recommended_chunk_size,
estimated_memory_savings,
estimated_speedup,
}
}
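/// Renders a [`SparsityProfile`] into human-readable tuning suggestions.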
pub fn generate_optimization_recommendations(profile: &SparsityProfile) -> Vec<String> {
let mut recommendations = Vec::new();
recommendations.push(format!(
"Use {:?} storage format for {:.1}% sparsity",
profile.recommended_format,
profile.overall_sparsity * 100.0
));
recommendations.push(format!(
"Use chunk size of {} dimensions based on {:?} pattern",
profile.recommended_chunk_size, profile.sparsity_pattern
));
if profile.estimated_memory_savings > 1.5 {
recommendations.push(format!(
"Expected {:.1}x memory reduction with recommended format",
profile.estimated_memory_savings
));
}
if profile.estimated_speedup > 1.5 {
recommendations.push(format!(
"Expected {:.1}x speedup with plane skipping optimization",
profile.estimated_speedup
));
}
match profile.sparsity_pattern {
SparsityPattern::Uniform => {
recommendations.push("Uniform sparsity: standard optimizations effective".to_string());
}
SparsityPattern::Clustered => {
recommendations
.push("Clustered sparsity: use smaller chunks for precision".to_string());
}
SparsityPattern::Layered => {
recommendations
.push("Layered sparsity: consider per-layer format selection".to_string());
}
}
if profile.overall_sparsity > 0.98 {
recommendations.push("Extreme sparsity: consider RLE compression".to_string());
}
recommendations
}
#[cfg(test)]
mod tests_sparsity_profiler {
use super::*;
#[cfg(feature = "cuda")]
#[test]
fn test_profile_tensor_sparsity() {
use crate::kernels::quantization::quantize_tensor_to_ternary;
let device = candle_core::Device::Cpu;
let mut data_vec = vec![0.0f32; 1000];
for i in 0..100 {
data_vec[i * 10] = 1.0;
}
let data = Tensor::from_slice(&data_vec, &[10, 100], &device).unwrap();
let ternary = quantize_tensor_to_ternary(&data, 0.5).unwrap();
let profile = profile_tensor_sparsity(&ternary, false);
assert!(profile.overall_sparsity > 0.85);
assert!(profile.overall_sparsity < 0.95);
assert!(profile.estimated_speedup > 1.0);
}
#[cfg(feature = "cuda")]
#[test]
fn test_generate_recommendations() {
use crate::kernels::quantization::quantize_tensor_to_ternary;
use candle_core::DType;
let device = candle_core::Device::Cpu;
let data = Tensor::zeros(&[10, 100], DType::F32, &device).unwrap();
let ternary = quantize_tensor_to_ternary(&data, 0.5).unwrap();
let profile = profile_tensor_sparsity(&ternary, false);
let recommendations = generate_optimization_recommendations(&profile);
assert!(!recommendations.is_empty());
assert!(recommendations.iter().any(|r| r.contains("storage format")));
}
}