/// A two-dimensional tensor of ternary values (-1, 0, +1) held in packed form.
///
/// `packed_data` is expected to hold 2-bit codes as produced by `pack_ternary`
/// (four elements per byte) and is decoded with `unpack_ternary`.
#[derive(Debug, Clone)]
pub struct TernaryTensor {
/// Packed 2-bit codes, four per byte; element `i` lives at bits `(i % 4) * 2` of byte `i / 4`.
pub packed_data: Vec<u8>,
/// `f32` scale factors; presumably one per `block_size`-element block — TODO confirm with producers.
pub scales: Vec<f32>,
/// Two-dimensional shape; only the element count `shape.0 * shape.1` is used in this file.
pub shape: (usize, usize),
/// Number of elements per block, used by `num_blocks`; `0` is treated as "no blocks".
pub block_size: usize,
}
impl TernaryTensor {
    /// Fraction of elements whose decoded ternary value is zero, in `[0.0, 1.0]`.
    ///
    /// The element count is `shape.0 * shape.1` (saturating on overflow). Values are
    /// decoded from `packed_data` with `unpack_ternary`. If `packed_data` holds fewer
    /// elements than the shape implies, the missing elements are not counted as zeros
    /// but still contribute to the denominator, so sparsity is under-reported.
    ///
    /// Returns `0.0` for a tensor with no elements.
    pub fn sparsity(&self) -> f32 {
        let total_elements = self.shape.0.saturating_mul(self.shape.1);
        if total_elements == 0 {
            return 0.0;
        }
        let zero_count = unpack_ternary(&self.packed_data, total_elements)
            .into_iter()
            .filter(|&x| x == 0)
            .count();
        zero_count as f32 / total_elements as f32
    }

    /// Bytes occupied by the packed codes plus the `f32` scales.
    ///
    /// Counts only the initialized elements (`len`) of both vectors; spare `Vec`
    /// capacity and the struct itself are not included.
    pub fn memory_bytes(&self) -> usize {
        self.packed_data.len() + std::mem::size_of_val(self.scales.as_slice())
    }

    /// Number of `block_size`-element blocks needed to cover every element.
    ///
    /// Rounds up, so a trailing partial block counts as one block. Returns `0` when
    /// `block_size` is `0`.
    pub fn num_blocks(&self) -> usize {
        if self.block_size == 0 {
            return 0;
        }
        let total_elements = self.shape.0.saturating_mul(self.shape.1);
        // Ceiling division via quotient + remainder. The `total + block_size - 1`
        // form saturates near `usize::MAX` and then rounds down by one block.
        total_elements / self.block_size + usize::from(total_elements % self.block_size != 0)
    }
}
/// Packs ternary values into 2-bit codes, four per byte, low bits first.
///
/// Encoding: negative -> `0b00`, zero -> `0b01`, positive -> `0b10`. Values outside
/// `-1..=1` are clamped by sign, so `-5` packs as `-1` and `3` packs as `1`. A final
/// partial group leaves its unused high bit pairs as zero.
pub fn pack_ternary(values: &[i8]) -> Vec<u8> {
    values
        .chunks(4)
        .map(|group| {
            group.iter().enumerate().fold(0u8, |byte, (slot, &value)| {
                let code: u8 = match value.cmp(&0) {
                    std::cmp::Ordering::Less => 0b00,
                    std::cmp::Ordering::Equal => 0b01,
                    std::cmp::Ordering::Greater => 0b10,
                };
                byte | (code << (slot * 2))
            })
        })
        .collect()
}
/// Decodes up to `n` ternary values from 2-bit codes packed four per byte.
///
/// Element `i` is read from bits `(i % 4) * 2` of byte `i / 4`. Decoding:
/// `0b00` -> -1, `0b01` -> 0, `0b10` -> 1. The reserved code `0b11` also decodes to 0.
///
/// If `packed` holds fewer than `n` values, decoding stops at the end of the buffer
/// and the result is shorter than `n`. The output is sized from what the buffer can
/// actually hold, so an oversized `n` never causes a large up-front allocation.
pub fn unpack_ternary(packed: &[u8], n: usize) -> Vec<i8> {
    // Clamp before allocating: `n` may far exceed the buffer (or be `usize::MAX`).
    let count = n.min(packed.len().saturating_mul(4));
    let mut values = Vec::with_capacity(count);
    for i in 0..count {
        let code = (packed[i / 4] >> ((i % 4) * 2)) & 0b11;
        let value: i8 = match code {
            0b00 => -1,
            0b01 | 0b11 => 0,
            0b10 => 1,
            _ => unreachable!("code is masked to two bits"),
        };
        values.push(value);
    }
    values
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a tensor with no packed data or scales; only `shape` and `block_size` matter.
    fn shaped(shape: (usize, usize), block_size: usize) -> TernaryTensor {
        TernaryTensor {
            packed_data: vec![],
            scales: vec![],
            shape,
            block_size,
        }
    }

    #[test]
    fn test_pack_unpack_ternary() {
        let values = vec![-1, 0, 1, -1, 1, 0, 0, 1];
        let packed = pack_ternary(&values);
        let unpacked = unpack_ternary(&packed, values.len());
        assert_eq!(values, unpacked);
    }

    #[test]
    fn test_pack_ternary_single_byte() {
        let values = vec![-1, 0, 1, -1];
        let packed = pack_ternary(&values);
        assert_eq!(packed.len(), 1);
        assert_eq!(packed[0], 0b00_10_01_00);
    }

    #[test]
    fn test_pack_ternary_partial_byte() {
        let values = vec![-1, 0, 1, -1, 1];
        let packed = pack_ternary(&values);
        assert_eq!(packed.len(), 2);
        // Only the low bit pair of the trailing byte is populated.
        assert_eq!(packed[1], 0b10);
    }

    #[test]
    fn test_pack_ternary_empty() {
        assert!(pack_ternary(&[]).is_empty());
    }

    #[test]
    fn test_pack_clamps_invalid_value() {
        let values = vec![-5, 0, 2, 3];
        let packed = pack_ternary(&values);
        let unpacked = unpack_ternary(&packed, 4);
        assert_eq!(unpacked, vec![-1i8, 0, 1, 1]);
    }

    #[test]
    fn test_unpack_stops_at_end_of_buffer() {
        let unpacked = unpack_ternary(&[0b00_10_01_00], 10);
        assert_eq!(unpacked, vec![-1i8, 0, 1, -1]);
    }

    #[test]
    fn test_unpack_reserved_code_decodes_to_zero() {
        assert_eq!(unpack_ternary(&[0b11], 1), vec![0i8]);
    }

    #[test]
    fn test_ternary_tensor_sparsity() {
        // Five of the eight values are zero.
        let values = vec![0, 1, 0, -1, 0, 0, 1, 0];
        let tensor = TernaryTensor {
            packed_data: pack_ternary(&values),
            scales: vec![1.0],
            shape: (2, 4),
            block_size: 256,
        };
        assert!((tensor.sparsity() - 0.625).abs() < 0.001);
    }

    #[test]
    fn test_ternary_tensor_sparsity_empty_shape() {
        assert_eq!(shaped((0, 8), 4).sparsity(), 0.0);
    }

    #[test]
    fn test_ternary_tensor_memory() {
        // 64 packed bytes plus 16 scales * 4 bytes each.
        let tensor = TernaryTensor {
            packed_data: vec![0u8; 64],
            scales: vec![0.5f32; 16],
            shape: (128, 256),
            block_size: 256,
        };
        assert_eq!(tensor.memory_bytes(), 64 + 64);
    }

    #[test]
    fn test_ternary_tensor_num_blocks() {
        assert_eq!(shaped((256, 256), 256).num_blocks(), 256);
    }

    #[test]
    fn test_ternary_tensor_num_blocks_rounds_up_partial_block() {
        // 15 elements in blocks of 4 -> 3 full blocks + 1 partial.
        assert_eq!(shaped((3, 5), 4).num_blocks(), 4);
    }

    #[test]
    fn test_ternary_tensor_num_blocks_zero_block_size() {
        assert_eq!(shaped((3, 5), 0).num_blocks(), 0);
    }
}