use super::config::TernaryConfig;
use candle_core::{Device, Shape, Tensor};
pub use trit_vsa::{PackedTritVec, Trit};
/// Bit-packed ternary vector stored as two parallel bit planes.
///
/// Dimension `d` is held in bit `d % 32` of word `d / 32`. A set bit in
/// `plus` encodes `+1`, a set bit in `minus` encodes `-1`, and a bit that is
/// clear in both planes encodes `0`. The same bit is never expected to be set
/// in both planes at once.
#[derive(Debug, Clone)]
pub struct TernaryPlanes {
    /// Words whose set bits mark the dimensions holding `+1`.
    pub plus: Vec<u32>,
    /// Words whose set bits mark the dimensions holding `-1`.
    pub minus: Vec<u32>,
    /// Number of logical dimensions; `new` allocates `num_dims.div_ceil(32)`
    /// words per plane.
    pub num_dims: usize,
}
impl TernaryPlanes {
    /// Creates an all-zero ternary vector with `num_dims` dimensions.
    ///
    /// Each plane is allocated with `num_dims.div_ceil(32)` zeroed words.
    #[must_use]
    pub fn new(num_dims: usize) -> Self {
        let num_words = num_dims.div_ceil(32);
        Self {
            plus: vec![0u32; num_words],
            minus: vec![0u32; num_words],
            num_dims,
        }
    }

    /// Returns the number of `u32` words per plane (measured on `plus`).
    #[must_use]
    pub fn num_words(&self) -> usize {
        self.plus.len()
    }

    /// Maps a dimension index to `(word index, single-bit mask)`.
    fn locate(dim: usize) -> (usize, u32) {
        (dim / 32, 1u32 << (dim % 32))
    }

    /// Stores `value` at dimension `dim`.
    ///
    /// # Panics
    ///
    /// Panics if `dim >= num_dims` or if `value` is not `-1`, `0` or `+1`.
    pub fn set(&mut self, dim: usize, value: i8) {
        assert!(dim < self.num_dims, "dimension out of bounds");
        assert!(
            (-1..=1).contains(&value),
            "value must be -1, 0, or +1, got {value}"
        );
        let (word_idx, mask) = Self::locate(dim);
        // Clear the bit in both planes first so the previous value can never
        // linger alongside the new one.
        self.plus[word_idx] &= !mask;
        self.minus[word_idx] &= !mask;
        match value {
            1 => self.plus[word_idx] |= mask,
            -1 => self.minus[word_idx] |= mask,
            // Zero is encoded as "clear in both planes", which is already the case.
            0 => {}
            _ => unreachable!(),
        }
    }

    /// Reads the value (`-1`, `0` or `+1`) stored at dimension `dim`.
    ///
    /// If the bit is (invalidly) set in both planes, debug builds trip a
    /// `debug_assert!`; otherwise the plus plane wins and `1` is returned.
    ///
    /// # Panics
    ///
    /// Panics if `dim >= num_dims`.
    #[must_use]
    pub fn get(&self, dim: usize) -> i8 {
        assert!(dim < self.num_dims, "dimension out of bounds");
        let (word_idx, mask) = Self::locate(dim);
        let is_plus = (self.plus[word_idx] & mask) != 0;
        let is_minus = (self.minus[word_idx] & mask) != 0;
        debug_assert!(
            !(is_plus && is_minus),
            "invalid state: both planes set at dim {dim}"
        );
        match (is_plus, is_minus) {
            (true, _) => 1,
            (false, true) => -1,
            (false, false) => 0,
        }
    }

    /// Counts the set bits across both planes, i.e. the non-zero entries.
    ///
    /// The count is accumulated in `usize` so that very large planes cannot
    /// overflow a 32-bit intermediate.
    #[must_use]
    pub fn count_nonzero(&self) -> usize {
        self.plus
            .iter()
            .chain(&self.minus)
            .map(|word| word.count_ones() as usize)
            .sum()
    }

    /// Returns the fraction of dimensions that are zero, in `[0, 1]`.
    ///
    /// A vector with no dimensions has no non-zero entries and is reported as
    /// fully sparse (`1.0`) rather than producing `NaN` from `0 / 0`.
    #[must_use]
    pub fn sparsity(&self) -> f32 {
        if self.num_dims == 0 {
            return 1.0;
        }
        #[allow(clippy::cast_precision_loss)]
        let density = self.count_nonzero() as f32 / self.num_dims as f32;
        1.0 - density
    }

    /// Computes the ternary dot product with `other` using popcounts.
    ///
    /// Matching signs (`+1·+1`, `-1·-1`) contribute `+1` each and opposing
    /// signs contribute `-1` each.
    ///
    /// # Panics
    ///
    /// Panics if the two vectors do not have the same number of words. Only
    /// the word count is compared, not `num_dims`.
    #[must_use]
    pub fn dot(&self, other: &TernaryPlanes) -> i32 {
        assert_eq!(
            self.num_words(),
            other.num_words(),
            "planes must have same size"
        );
        (0..self.num_words())
            .map(|i| {
                let (a_plus, a_minus) = (self.plus[i], self.minus[i]);
                let (b_plus, b_minus) = (other.plus[i], other.minus[i]);
                let pp = (a_plus & b_plus).count_ones().cast_signed();
                let mm = (a_minus & b_minus).count_ones().cast_signed();
                let pm = (a_plus & b_minus).count_ones().cast_signed();
                let mp = (a_minus & b_plus).count_ones().cast_signed();
                pp + mm - pm - mp
            })
            .sum()
    }

    /// Converts to a [`PackedTritVec`] by handing it copies of both planes.
    ///
    /// # Panics
    ///
    /// Panics if `PackedTritVec::from_planes` rejects the planes. Because the
    /// fields are public, that can happen if they were modified into a state
    /// that constructor considers invalid.
    #[must_use]
    pub fn to_packed_trit_vec(&self) -> PackedTritVec {
        PackedTritVec::from_planes(self.plus.clone(), self.minus.clone(), self.num_dims)
            .expect("TernaryPlanes invariants guarantee valid planes")
    }

    /// Builds a `TernaryPlanes` by copying both planes and the length out of
    /// a [`PackedTritVec`].
    #[must_use]
    pub fn from_packed_trit_vec(packed: &PackedTritVec) -> Self {
        Self {
            plus: packed.plus_plane().to_vec(),
            minus: packed.minus_plane().to_vec(),
            num_dims: packed.len(),
        }
    }
}
/// Chunk-level activity bitmap for a ternary vector.
///
/// The vector's dimensions are split into consecutive chunks of `chunk_size`
/// dimensions; a chunk is flagged "active" when non-zero data was found for it.
#[derive(Debug, Clone)]
pub struct SparsityMetadata {
    /// Activity bitmap: chunk `c` is active when bit `c % 64` of word
    /// `c / 64` is set.
    pub active_chunks: Vec<u64>,
    /// Number of dimensions covered by each chunk.
    pub chunk_size: usize,
    /// Total number of chunks tracked by the bitmap.
    pub num_chunks: usize,
}
impl SparsityMetadata {
    /// Builds the chunk activity bitmap for `planes`.
    ///
    /// The dimensions are split into `num_dims.div_ceil(chunk_size)` chunks.
    /// Activity is tested on whole 32-bit words: every word overlapping the
    /// chunk's dimension range is inspected. When `chunk_size` is not a
    /// multiple of 32, a chunk can therefore be reported active because of a
    /// non-zero entry that belongs to a neighbouring chunk sharing a boundary
    /// word. The result never misses a non-zero chunk, but may over-report.
    ///
    /// # Panics
    ///
    /// Panics if `chunk_size` is zero, or if either plane holds fewer words
    /// than `num_dims` requires.
    #[must_use]
    pub fn from_planes(planes: &TernaryPlanes, chunk_size: usize) -> Self {
        // Fail with a clear message instead of a divide-by-zero inside `div_ceil`.
        assert!(chunk_size > 0, "chunk_size must be greater than zero");
        let num_chunks = planes.num_dims.div_ceil(chunk_size);
        let mut active_chunks = vec![0u64; num_chunks.div_ceil(64)];
        for chunk_idx in 0..num_chunks {
            let start_dim = chunk_idx * chunk_size;
            // The last chunk may be shorter than `chunk_size`.
            let end_dim = start_dim.saturating_add(chunk_size).min(planes.num_dims);
            let words = (start_dim / 32)..end_dim.div_ceil(32);
            let is_active = planes.plus[words.clone()]
                .iter()
                .zip(&planes.minus[words])
                .any(|(&plus, &minus)| (plus | minus) != 0);
            if is_active {
                active_chunks[chunk_idx / 64] |= 1u64 << (chunk_idx % 64);
            }
        }
        Self {
            active_chunks,
            chunk_size,
            num_chunks,
        }
    }

    /// Returns whether chunk `chunk_idx` is flagged active.
    ///
    /// Indices at or beyond `num_chunks` are reported as inactive.
    #[must_use]
    pub fn is_chunk_active(&self, chunk_idx: usize) -> bool {
        if chunk_idx >= self.num_chunks {
            return false;
        }
        let bitmap_idx = chunk_idx / 64;
        let bit_idx = chunk_idx % 64;
        (self.active_chunks[bitmap_idx] & (1u64 << bit_idx)) != 0
    }

    /// Returns the number of chunks flagged active (set bits in the bitmap).
    #[must_use]
    pub fn active_count(&self) -> usize {
        self.active_chunks
            .iter()
            .map(|word| word.count_ones() as usize)
            .sum()
    }

    /// Returns the fraction of chunks that are inactive, in `[0, 1]`.
    ///
    /// With zero chunks there is nothing active, so `1.0` is returned rather
    /// than `NaN` from `0 / 0`.
    #[must_use]
    pub fn chunk_sparsity(&self) -> f32 {
        if self.num_chunks == 0 {
            return 1.0;
        }
        #[allow(clippy::cast_precision_loss)]
        let active_fraction = self.active_count() as f32 / self.num_chunks as f32;
        1.0 - active_fraction
    }
}
/// Two-dimensional ternary matrix stored as row-major bit planes.
///
/// Each row occupies `k_words` consecutive `u32` words in both planes; column
/// `c` of a row is bit `c % 32` of that row's word `c / 32`.
#[derive(Debug, Clone)]
pub struct TernaryTensor {
    /// Row-major words whose set bits mark `+1` entries (`shape.0 * k_words` words).
    pub plus_plane: Vec<u32>,
    /// Row-major words whose set bits mark `-1` entries (`shape.0 * k_words` words).
    pub minus_plane: Vec<u32>,
    /// One `f32` per row — presumably the row's dequantization scale; this
    /// file only stores and exports the values (confirm against the caller).
    pub scales: Vec<f32>,
    /// Logical dimensions as `(rows, columns)`.
    pub shape: (usize, usize),
    /// Words per row, i.e. `shape.1.div_ceil(32)`.
    pub k_words: usize,
    /// Per-row chunk activity, `None` until `build_sparsity_metadata` runs.
    pub sparsity_meta: Option<Vec<SparsityMetadata>>,
    /// Cached fraction of zero entries; refreshed only by the constructor and
    /// `recalculate_sparsity`.
    sparsity: f32,
}
impl TernaryTensor {
#[must_use]
pub fn new(
plus_plane: Vec<u32>,
minus_plane: Vec<u32>,
scales: Vec<f32>,
shape: (usize, usize),
) -> Self {
let k_words = shape.1.div_ceil(32);
let expected_len = shape.0 * k_words;
assert_eq!(
plus_plane.len(),
expected_len,
"plus_plane size mismatch: expected {expected_len}, got {}",
plus_plane.len()
);
assert_eq!(
minus_plane.len(),
expected_len,
"minus_plane size mismatch: expected {expected_len}, got {}",
minus_plane.len()
);
assert_eq!(
scales.len(),
shape.0,
"scales size mismatch: expected {}, got {}",
shape.0,
scales.len()
);
let plus_ones: u32 = plus_plane.iter().map(|w| w.count_ones()).sum();
let minus_ones: u32 = minus_plane.iter().map(|w| w.count_ones()).sum();
let total_elements = shape.0 * shape.1;
let nonzero = plus_ones + minus_ones;
#[allow(clippy::cast_precision_loss)]
let sparsity = 1.0 - (nonzero as f32 / total_elements as f32);
Self {
plus_plane,
minus_plane,
scales,
shape,
k_words,
sparsity_meta: None,
sparsity,
}
}
#[must_use]
pub const fn dims(&self) -> (usize, usize) {
self.shape
}
#[must_use]
pub const fn sparsity(&self) -> f32 {
self.sparsity
}
#[must_use]
pub fn is_sparse_enough(&self, config: &TernaryConfig) -> bool {
self.sparsity >= config.sparsity_threshold
}
pub fn build_sparsity_metadata(&mut self, chunk_size: usize) {
let mut metadata = Vec::with_capacity(self.shape.0);
for row in 0..self.shape.0 {
let row_offset = row * self.k_words;
let plus_row: Vec<u32> =
self.plus_plane[row_offset..row_offset + self.k_words].to_vec();
let minus_row: Vec<u32> =
self.minus_plane[row_offset..row_offset + self.k_words].to_vec();
let planes = TernaryPlanes {
plus: plus_row,
minus: minus_row,
num_dims: self.shape.1,
};
metadata.push(SparsityMetadata::from_planes(&planes, chunk_size));
}
self.sparsity_meta = Some(metadata);
}
#[must_use]
pub fn get_row_planes(&self, row: usize) -> TernaryPlanes {
assert!(row < self.shape.0, "row out of bounds");
let row_offset = row * self.k_words;
TernaryPlanes {
plus: self.plus_plane[row_offset..row_offset + self.k_words].to_vec(),
minus: self.minus_plane[row_offset..row_offset + self.k_words].to_vec(),
num_dims: self.shape.1,
}
}
#[must_use]
pub fn memory_bytes(&self) -> usize {
let plane_bytes = self.plus_plane.len() * 4 * 2; let scale_bytes = self.scales.len() * 4; let meta_bytes = self
.sparsity_meta
.as_ref()
.map_or(0, |m| m.iter().map(|s| s.active_chunks.len() * 8).sum());
plane_bytes + scale_bytes + meta_bytes
}
#[must_use]
pub fn compression_ratio(&self) -> f32 {
let fp32_bytes = self.shape.0 * self.shape.1 * 4;
#[allow(clippy::cast_precision_loss)]
{
fp32_bytes as f32 / self.memory_bytes() as f32
}
}
pub fn plus_plane_tensor(&self, device: &Device) -> candle_core::Result<Tensor> {
let shape = Shape::from_dims(&[self.shape.0, self.k_words]);
let data: Vec<f32> = self.plus_plane.iter().map(|&x| f32::from_bits(x)).collect();
Tensor::from_vec(data, shape, device)
}
pub fn minus_plane_tensor(&self, device: &Device) -> candle_core::Result<Tensor> {
let shape = Shape::from_dims(&[self.shape.0, self.k_words]);
let data: Vec<f32> = self
.minus_plane
.iter()
.map(|&x| f32::from_bits(x))
.collect();
Tensor::from_vec(data, shape, device)
}
pub fn scales_tensor(&self, device: &Device) -> candle_core::Result<Tensor> {
Tensor::from_vec(self.scales.clone(), self.shape.0, device)
}
pub fn modify_dim(&mut self, row: usize, col: usize, new_val: i8) {
assert!(
row < self.shape.0,
"row {} out of bounds (max {})",
row,
self.shape.0 - 1
);
assert!(
col < self.shape.1,
"col {} out of bounds (max {})",
col,
self.shape.1 - 1
);
assert!(
(-1..=1).contains(&new_val),
"new_val must be -1, 0, or +1, got {new_val}"
);
let word_idx = col / 32;
let bit_idx = col % 32;
let mask = 1u32 << bit_idx;
let plane_idx = row * self.k_words + word_idx;
self.plus_plane[plane_idx] &= !mask;
self.minus_plane[plane_idx] &= !mask;
match new_val {
1 => self.plus_plane[plane_idx] |= mask,
-1 => self.minus_plane[plane_idx] |= mask,
0 => {} _ => unreachable!(),
}
}
#[must_use]
pub fn get_dim(&self, row: usize, col: usize) -> i8 {
assert!(
row < self.shape.0,
"row index {} out of bounds for number of rows {}",
row,
self.shape.0
);
assert!(
col < self.shape.1,
"column index {} out of bounds for number of columns {}",
col,
self.shape.1
);
let word_idx = col / 32;
let bit_idx = col % 32;
let mask = 1u32 << bit_idx;
let plane_idx = row * self.k_words + word_idx;
let is_plus = (self.plus_plane[plane_idx] & mask) != 0;
let is_minus = (self.minus_plane[plane_idx] & mask) != 0;
debug_assert!(!(is_plus && is_minus), "invalid state: both planes set");
if is_plus {
1
} else if is_minus {
-1
} else {
0
}
}
pub fn recalculate_sparsity(&mut self) {
let plus_ones: u32 = self.plus_plane.iter().map(|w| w.count_ones()).sum();
let minus_ones: u32 = self.minus_plane.iter().map(|w| w.count_ones()).sum();
let total_elements = self.shape.0 * self.shape.1;
let nonzero = plus_ones + minus_ones;
#[allow(clippy::cast_precision_loss)]
{
self.sparsity = 1.0 - (nonzero as f32 / total_elements as f32);
}
}
pub fn prune_below_threshold(&mut self, _threshold: f32) -> usize {
0
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an all-zero tensor of the given shape with every scale set to 1.0.
    fn zero_tensor(rows: usize, cols: usize) -> TernaryTensor {
        let words = rows * cols.div_ceil(32);
        TernaryTensor::new(
            vec![0u32; words],
            vec![0u32; words],
            vec![1.0f32; rows],
            (rows, cols),
        )
    }

    /// Values written by `set` are read back by `get`; untouched dims stay zero.
    #[test]
    fn test_ternary_planes_basic() {
        const ENTRIES: [(usize, i8); 5] = [(0, 1), (1, -1), (2, 0), (50, 1), (99, -1)];
        let mut planes = TernaryPlanes::new(100);
        for (dim, value) in ENTRIES {
            planes.set(dim, value);
        }
        for (dim, value) in ENTRIES {
            assert_eq!(planes.get(dim), value);
        }
        assert_eq!(planes.get(10), 0);
    }

    /// Matching signs add one, opposing signs subtract one.
    #[test]
    fn test_ternary_planes_dot_product() {
        let mut a = TernaryPlanes::new(64);
        let mut b = TernaryPlanes::new(64);
        for (dim, value) in [(0usize, 1i8), (1, -1), (3, 1)] {
            a.set(dim, value);
        }
        for (dim, value) in [(0usize, 1i8), (1, 1), (2, -1)] {
            b.set(dim, value);
        }
        assert_eq!(a.dot(&b), 0);
        b.set(1, -1);
        assert_eq!(a.dot(&b), 2);
    }

    /// Five non-zero entries out of one hundred give 95% sparsity.
    #[test]
    fn test_sparsity_calculation() {
        let mut planes = TernaryPlanes::new(100);
        for (dim, value) in [(0usize, 1i8), (10, -1), (20, 1), (50, -1), (99, 1)] {
            planes.set(dim, value);
        }
        assert_eq!(planes.count_nonzero(), 5);
        assert!((planes.sparsity() - 0.95).abs() < 0.001);
    }

    /// Only the chunk containing the populated dims is flagged active.
    #[test]
    fn test_sparsity_metadata() {
        let mut planes = TernaryPlanes::new(1000);
        (0..10).for_each(|dim| planes.set(dim, 1));
        let meta = SparsityMetadata::from_planes(&planes, 100);
        assert!(meta.is_chunk_active(0));
        assert!(!meta.is_chunk_active(1));
        assert!(!meta.is_chunk_active(9));
    }

    /// A freshly built all-zero tensor reports its shape and full sparsity.
    #[test]
    fn test_ternary_tensor_creation() {
        let tensor = zero_tensor(4, 64);
        assert_eq!(tensor.dims(), (4, 64));
        assert_eq!(tensor.k_words, 2);
        assert!((tensor.sparsity() - 1.0).abs() < 0.001);
    }

    /// Two bit planes plus scales compress well against dense f32 storage.
    #[test]
    fn test_compression_ratio() {
        let tensor = zero_tensor(1024, 4096);
        let ratio = tensor.compression_ratio();
        assert!(ratio > 10.0 && ratio < 20.0);
    }

    /// Entries can be rewritten in place, including across word boundaries.
    #[test]
    fn test_modify_dim_basic() {
        let mut tensor = zero_tensor(4, 64);
        for col in [0, 31, 32] {
            assert_eq!(tensor.get_dim(0, col), 0);
        }
        for value in [1i8, -1, 0] {
            tensor.modify_dim(0, 0, value);
            assert_eq!(tensor.get_dim(0, 0), value);
        }
        tensor.modify_dim(0, 32, 1);
        assert_eq!(tensor.get_dim(0, 32), 1);
        tensor.modify_dim(0, 63, -1);
        assert_eq!(tensor.get_dim(0, 63), -1);
    }

    /// Writes to the same column of different rows do not interfere.
    #[test]
    fn test_modify_dim_different_rows() {
        const COLUMN_VALUES: [i8; 4] = [1, -1, 1, 0];
        let mut tensor = zero_tensor(4, 64);
        for (row, value) in COLUMN_VALUES.into_iter().enumerate() {
            tensor.modify_dim(row, 0, value);
        }
        for (row, value) in COLUMN_VALUES.into_iter().enumerate() {
            assert_eq!(tensor.get_dim(row, 0), value);
        }
    }

    /// The cached sparsity follows the planes once recalculated.
    #[test]
    fn test_recalculate_sparsity() {
        let mut tensor = zero_tensor(4, 64);
        assert!((tensor.sparsity() - 1.0).abs() < 0.001);
        for col in 0..10 {
            tensor.modify_dim(0, col, 1);
        }
        tensor.recalculate_sparsity();
        let expected = 1.0 - (10.0 / 256.0);
        assert!((tensor.sparsity() - expected).abs() < 0.001);
    }

    /// An out-of-range row is rejected.
    #[test]
    #[should_panic(expected = "row")]
    fn test_modify_dim_row_bounds() {
        let mut tensor = zero_tensor(4, 64);
        tensor.modify_dim(4, 0, 1);
    }

    /// An out-of-range column is rejected.
    #[test]
    #[should_panic(expected = "col")]
    fn test_modify_dim_col_bounds() {
        let mut tensor = zero_tensor(4, 64);
        tensor.modify_dim(0, 64, 1);
    }

    /// Values outside -1..=1 are rejected.
    #[test]
    #[should_panic(expected = "new_val")]
    fn test_modify_dim_invalid_value() {
        let mut tensor = zero_tensor(4, 64);
        tensor.modify_dim(0, 0, 2);
    }

    /// Converting to `PackedTritVec` and back preserves every entry.
    #[test]
    fn test_ternary_planes_to_packed_trit_vec_roundtrip() {
        let mut planes = TernaryPlanes::new(100);
        for (dim, value) in [(0usize, 1i8), (1, -1), (50, 1), (99, -1)] {
            planes.set(dim, value);
        }

        let packed = planes.to_packed_trit_vec();
        assert_eq!(packed.get(0), Trit::P);
        assert_eq!(packed.get(1), Trit::N);
        assert_eq!(packed.get(2), Trit::Z);
        assert_eq!(packed.get(50), Trit::P);
        assert_eq!(packed.get(99), Trit::N);

        let restored = TernaryPlanes::from_packed_trit_vec(&packed);
        for (dim, value) in [(0usize, 1i8), (1, -1), (2, 0), (50, 1), (99, -1)] {
            assert_eq!(restored.get(dim), value);
        }
    }

    /// `TernaryPlanes::dot` agrees with `PackedTritVec::dot`.
    #[test]
    fn test_packed_trit_vec_dot_equivalence() {
        let mut planes_a = TernaryPlanes::new(64);
        let mut planes_b = TernaryPlanes::new(64);
        for (dim, value) in [(0usize, 1i8), (1, -1), (10, 1)] {
            planes_a.set(dim, value);
            planes_b.set(dim, value);
        }
        let dot_planes = planes_a.dot(&planes_b);
        let dot_packed = planes_a
            .to_packed_trit_vec()
            .dot(&planes_b.to_packed_trit_vec());
        assert_eq!(dot_planes, dot_packed);
        assert_eq!(dot_planes, 3);
    }
}