/// A vector of ternary values (`-1`, `0`, `+1`) packed two bits per element into `u64` words.
///
/// Element `i` lives in word `i / 32` at bit offset `(i % 32) * 2`. The lane codes are
/// `0b00` = 0, `0b01` = +1 and `0b10` = -1; the code `0b11` is unused and decodes as 0.
#[derive(Clone, Debug, PartialEq)]
pub struct PackedTernary {
    /// Packed storage, 32 two-bit lanes per word, lane 0 in the least significant bits.
    data: Vec<u64>,
    /// Number of logical elements; lanes at or beyond this index are padding.
    dimension: usize,
}

impl PackedTernary {
    /// Number of two-bit lanes held by one `u64` word.
    const LANES_PER_WORD: usize = 32;
    /// Selects the low bit of every two-bit lane.
    const LANE_LOW_BITS: u64 = 0x5555_5555_5555_5555;

    /// Wraps already-packed words as a vector of `dimension` elements.
    ///
    /// The words are canonicalized before being stored: lanes holding the unused code `0b11`
    /// and padding lanes past `dimension` are cleared. Those lanes already decode as 0 through
    /// [`get`](Self::get), but whole-word operations (equality, dot and Hamming kernels) would
    /// otherwise see them.
    ///
    /// # Panics
    ///
    /// Panics if `data.len()` is not `dimension.div_ceil(32)`.
    pub fn new(mut data: Vec<u64>, dimension: usize) -> Self {
        assert_eq!(
            data.len(),
            dimension.div_ceil(32),
            "PackedTernary: data length {} doesn't match dimension {} (expected {} words)",
            data.len(),
            dimension,
            dimension.div_ceil(32)
        );
        for word in &mut data {
            // Low-bit marker for every lane that has both bits set (the unused code 0b11).
            let invalid = *word & (*word >> 1) & Self::LANE_LOW_BITS;
            *word &= !(invalid | (invalid << 1));
        }
        let tail_lanes = dimension % Self::LANES_PER_WORD;
        if tail_lanes != 0 {
            if let Some(last) = data.last_mut() {
                // Keep only the lanes that belong to real elements in the final word.
                *last &= (1u64 << (tail_lanes * 2)) - 1;
            }
        }
        Self { data, dimension }
    }

    /// Returns the packed words.
    pub fn data(&self) -> &[u64] {
        &self.data
    }

    /// Returns the number of logical elements.
    pub fn dimension(&self) -> usize {
        self.dimension
    }

    /// Creates a vector of `dimension` elements, all zero.
    pub fn zeros(dimension: usize) -> Self {
        Self {
            data: vec![0; dimension.div_ceil(Self::LANES_PER_WORD)],
            dimension,
        }
    }

    /// Stores `val` at element `idx`.
    ///
    /// `1` and `-1` are stored as such; any other `val` stores 0. An `idx` at or beyond
    /// `dimension` is silently ignored.
    pub fn set(&mut self, idx: usize, val: i8) {
        if idx >= self.dimension {
            return;
        }
        let shift = (idx % Self::LANES_PER_WORD) * 2;
        let code: u64 = match val {
            1 => 0b01,
            -1 => 0b10,
            _ => 0b00,
        };
        let word = &mut self.data[idx / Self::LANES_PER_WORD];
        // Clear the lane first so a previous value cannot leak into the new code.
        *word = (*word & !(0b11u64 << shift)) | (code << shift);
    }

    /// Returns the element at `idx` as `-1`, `0` or `1`.
    ///
    /// An `idx` at or beyond `dimension` yields 0, as does the unused lane code `0b11`.
    pub fn get(&self, idx: usize) -> i8 {
        if idx >= self.dimension {
            return 0;
        }
        let shift = (idx % Self::LANES_PER_WORD) * 2;
        match (self.data[idx / Self::LANES_PER_WORD] >> shift) & 0b11 {
            0b01 => 1,
            0b10 => -1,
            _ => 0,
        }
    }

    /// Returns the number of non-zero elements.
    ///
    /// Counts whole words with a popcount instead of decoding each element; the result is the
    /// same as counting the indices below `dimension` for which [`get`](Self::get) is non-zero.
    pub fn nnz(&self) -> usize {
        // A lane is non-zero exactly when its two bits differ (0b01 or 0b10); the XOR puts
        // that fact on the lane's low bit.
        let occupied = |word: u64| (word ^ (word >> 1)) & Self::LANE_LOW_BITS;
        let full_words = self.dimension / Self::LANES_PER_WORD;
        let tail_lanes = self.dimension % Self::LANES_PER_WORD;
        let mut count: usize = self.data[..full_words]
            .iter()
            .map(|&word| occupied(word).count_ones() as usize)
            .sum();
        if tail_lanes != 0 {
            // Ignore padding lanes in the final, partially used word.
            let tail_mask = (1u64 << (tail_lanes * 2)) - 1;
            count += (occupied(self.data[full_words]) & tail_mask).count_ones() as usize;
        }
        count
    }

    /// Returns the size in bytes of the packed words (excluding the struct itself).
    pub fn memory_bytes(&self) -> usize {
        self.data.len() * std::mem::size_of::<u64>()
    }
}
/// Quantizes `values` into a [`PackedTernary`] of the same length.
///
/// Element `i` becomes `+1` when `values[i] > threshold`, otherwise `-1` when
/// `values[i] < -threshold`, otherwise `0`. The comparisons are strict, so a value exactly at
/// `threshold` or `-threshold` maps to `0`.
#[must_use]
pub fn encode_ternary(values: &[f32], threshold: f32) -> PackedTernary {
    let mut packed = PackedTernary::zeros(values.len());
    let lower = -threshold;
    for (idx, value) in values.iter().copied().enumerate() {
        // The vector starts out all-zero, so only non-zero trits need to be written.
        let trit: i8 = if value > threshold {
            1
        } else if value < lower {
            -1
        } else {
            continue;
        };
        packed.set(idx, trit);
    }
    packed
}
/// Computes the dot product of two packed ternary vectors.
///
/// On x86_64 the POPCNT kernel is used when the running CPU reports support for it; in every
/// other case (including all other architectures) the portable kernel is used.
///
/// # Panics
///
/// Panics if `a` and `b` differ in dimension or in number of packed words.
#[inline]
#[must_use]
// `unsafe` is required only to call the `#[target_feature]` kernel after runtime detection.
#[allow(unsafe_code)]
pub fn ternary_dot(a: &PackedTernary, b: &PackedTernary) -> i32 {
    assert_eq!(a.dimension, b.dimension);
    assert_eq!(a.data.len(), b.data.len());
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("popcnt") {
            // SAFETY: `ternary_dot_popcnt` is compiled with the `popcnt` target feature, and
            // that feature was just detected on the running CPU.
            return unsafe { ternary_dot_popcnt(&a.data, &b.data) };
        }
    }
    ternary_dot_portable(&a.data, &b.data)
}
/// Portable dot-product kernel over packed ternary words.
///
/// Words are paired positionally (extra words in the longer slice are ignored). Each lane
/// where both inputs hold the same non-zero sign adds 1, each lane with opposite signs
/// subtracts 1; lanes holding `0b00` or `0b11` contribute nothing.
fn ternary_dot_portable(a: &[u64], b: &[u64]) -> i32 {
    // Low bit of every two-bit lane.
    const LANE_LOW: u64 = 0x5555555555555555;

    // Splits a word into (+1 lanes, -1 lanes), each reported on the lane's low bit.
    let split = |word: u64| {
        let lo = word & LANE_LOW;
        let hi = (word >> 1) & LANE_LOW;
        (lo & !hi, hi & !lo)
    };

    let (agree, oppose) = a
        .iter()
        .zip(b)
        .fold((0u32, 0u32), |(agree, oppose), (&wa, &wb)| {
            let (pos_a, neg_a) = split(wa);
            let (pos_b, neg_b) = split(wb);
            let same_sign = (pos_a & pos_b) | (neg_a & neg_b);
            let opposite_sign = (pos_a & neg_b) | (neg_a & pos_b);
            (
                agree + same_sign.count_ones(),
                oppose + opposite_sign.count_ones(),
            )
        });

    agree as i32 - oppose as i32
}
/// POPCNT-based dot-product kernel over packed ternary words (x86_64 only).
///
/// Words are paired positionally. Each lane where both inputs hold the same non-zero sign
/// adds 1, each lane with opposite signs subtracts 1; lanes holding `0b00` or `0b11`
/// contribute nothing.
///
/// # Safety
///
/// The function is compiled with the `popcnt` target feature, so the caller must make sure
/// the running CPU supports that instruction.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "popcnt")]
#[allow(unsafe_code)]
unsafe fn ternary_dot_popcnt(a: &[u64], b: &[u64]) -> i32 {
    use std::arch::x86_64::_popcnt64;

    // Low bit of every two-bit lane.
    const LANE_LOW: u64 = 0x5555555555555555;

    // Running total of (same-sign lanes) - (opposite-sign lanes).
    let mut net: i64 = 0;
    for (&wa, &wb) in a.iter().zip(b) {
        let (lo_a, hi_a) = (wa & LANE_LOW, (wa >> 1) & LANE_LOW);
        let (lo_b, hi_b) = (wb & LANE_LOW, (wb >> 1) & LANE_LOW);
        // +1 lanes have only the low bit set, -1 lanes only the high bit.
        let (pos_a, neg_a) = (lo_a & !hi_a, hi_a & !lo_a);
        let (pos_b, neg_b) = (lo_b & !hi_b, hi_b & !lo_b);
        let same_sign = (pos_a & pos_b) | (neg_a & neg_b);
        let opposite_sign = (pos_a & neg_b) | (neg_a & pos_b);
        net += i64::from(_popcnt64(same_sign as i64));
        net -= i64::from(_popcnt64(opposite_sign as i64));
    }
    net as i32
}
/// Computes the dot product of a full-precision `query` with a packed ternary vector.
///
/// Every element is decoded to `-1.0`, `0.0` or `1.0`, multiplied by the matching query
/// value and accumulated in index order.
///
/// # Panics
///
/// Panics if `query.len()` differs from the dimension of `ternary`.
#[inline]
#[must_use]
pub fn asymmetric_dot(query: &[f32], ternary: &PackedTernary) -> f32 {
    assert_eq!(query.len(), ternary.dimension);
    query
        .iter()
        .enumerate()
        .fold(0.0f32, |acc, (idx, &q)| acc + q * f32::from(ternary.get(idx)))
}
/// Counts the lanes in which both vectors are non-zero and their two-bit codes differ.
///
/// Lanes where either vector holds zero are ignored, so for `±1` data this is the number of
/// positions with opposite signs.
///
/// # Panics
///
/// Panics if `a` and `b` differ in dimension.
#[must_use]
pub fn ternary_hamming(a: &PackedTernary, b: &PackedTernary) -> u32 {
    // Low bit of every two-bit lane.
    const LANE_LOW: u64 = 0x5555555555555555;

    assert_eq!(a.dimension, b.dimension);

    // Collapses each lane to one flag on its low bit: set when either bit of the lane is set.
    let any_bit = |word: u64| (word | (word >> 1)) & LANE_LOW;

    a.data
        .iter()
        .zip(&b.data)
        .fold(0u32, |total, (&wa, &wb)| {
            let both_nonzero = any_bit(wa) & any_bit(wb);
            let codes_differ = any_bit(wa ^ wb);
            total + (codes_differ & both_nonzero).count_ones()
        })
}
/// Returns the fraction of elements of `v` that are zero, as `1 - nnz / dimension`.
///
/// A vector of dimension 0 reports a sparsity of `0.0`.
#[must_use]
pub fn sparsity(v: &PackedTernary) -> f32 {
    match v.dimension() {
        0 => 0.0,
        dim => {
            let density = v.nnz() as f32 / dim as f32;
            1.0 - density
        }
    }
}
// Unit tests for packing, encoding and the dot / Hamming / sparsity functions.
#[cfg(test)]
mod tests {
    use super::*;

    // Values outside the ±0.3 dead zone become ±1, values inside it become 0.
    #[test]
    fn test_encode_decode() {
        let values = vec![0.5, -0.5, 0.1, -0.1, 0.8, -0.8];
        let packed = encode_ternary(&values, 0.3);
        assert_eq!(packed.get(0), 1);
        assert_eq!(packed.get(1), -1);
        assert_eq!(packed.get(2), 0);
        assert_eq!(packed.get(3), 0);
        assert_eq!(packed.get(4), 1);
        assert_eq!(packed.get(5), -1);
    }

    // dot(a, a) counts the non-zero elements.
    #[test]
    fn test_ternary_dot_same() {
        let mut a = PackedTernary::zeros(4);
        a.set(0, 1);
        a.set(1, -1);
        a.set(2, 0);
        a.set(3, 1);
        let dot = ternary_dot(&a, &a);
        assert_eq!(dot, 3);
    }

    // Opposite signs in the same position contribute -1 each.
    #[test]
    fn test_ternary_dot_opposite() {
        let mut a = PackedTernary::zeros(4);
        let mut b = PackedTernary::zeros(4);
        a.set(0, 1);
        a.set(1, -1);
        b.set(0, -1);
        b.set(1, 1);
        let dot = ternary_dot(&a, &b);
        assert_eq!(dot, -2);
    }

    // Non-overlapping supports give a dot product of 0.
    #[test]
    fn test_ternary_dot_orthogonal() {
        let mut a = PackedTernary::zeros(4);
        let mut b = PackedTernary::zeros(4);
        a.set(0, 1);
        a.set(1, 0);
        b.set(0, 0);
        b.set(1, 1);
        let dot = ternary_dot(&a, &b);
        assert_eq!(dot, 0);
    }

    // 768 elements pack into 24 words (192 bytes); dot(v, v) equals nnz(v).
    #[test]
    fn test_large_vector() {
        let values: Vec<f32> = (0..768)
            .map(|i| {
                let x = (i as f32 / 768.0) - 0.5;
                if i % 3 == 0 {
                    x * 2.0
                } else {
                    x * 0.5
                }
            })
            .collect();
        let packed = encode_ternary(&values, 0.3);
        assert_eq!(packed.data.len(), 24);
        assert_eq!(packed.memory_bytes(), 192);
        let dot = ternary_dot(&packed, &packed);
        assert_eq!(dot as usize, packed.nnz());
    }

    // 0.5 * (+1) + 0.5 * (-1) + 0.5 * 0 + 0.5 * (+1) = 0.5
    #[test]
    fn test_asymmetric_dot() {
        let mut t = PackedTernary::zeros(4);
        t.set(0, 1);
        t.set(1, -1);
        t.set(2, 0);
        t.set(3, 1);
        let query = vec![0.5, 0.5, 0.5, 0.5];
        let dot = asymmetric_dot(&query, &t);
        assert!((dot - 0.5).abs() < 1e-6);
    }

    // Positions 1 and 2 have opposite signs; position 0 agrees.
    #[test]
    fn test_hamming() {
        let mut a = PackedTernary::zeros(4);
        let mut b = PackedTernary::zeros(4);
        a.set(0, 1);
        a.set(1, -1);
        a.set(2, 1);
        b.set(0, 1);
        b.set(1, 1);
        b.set(2, -1);
        let hamming = ternary_hamming(&a, &b);
        assert_eq!(hamming, 2);
    }

    #[test]
    fn test_zeros_all_zero() {
        let v = PackedTernary::zeros(100);
        for i in 0..100 {
            assert_eq!(v.get(i), 0, "index {i} should be 0");
        }
        assert_eq!(v.nnz(), 0);
    }

    #[test]
    fn test_set_get_all_values() {
        let mut v = PackedTernary::zeros(3);
        v.set(0, 1);
        v.set(1, -1);
        v.set(2, 0);
        assert_eq!(v.get(0), 1);
        assert_eq!(v.get(1), -1);
        assert_eq!(v.get(2), 0);
    }

    // Writing a lane must fully replace the previous value.
    #[test]
    fn test_set_overwrite() {
        let mut v = PackedTernary::zeros(1);
        v.set(0, 1);
        assert_eq!(v.get(0), 1);
        v.set(0, -1);
        assert_eq!(v.get(0), -1);
        v.set(0, 0);
        assert_eq!(v.get(0), 0);
    }

    // Out-of-range writes must neither panic nor modify the vector.
    #[test]
    fn test_set_out_of_bounds_is_noop() {
        let mut v = PackedTernary::zeros(4);
        v.set(100, 1);
        // `idx == dimension` is the first out-of-range index.
        v.set(4, 1);
        assert_eq!(v, PackedTernary::zeros(4));
        assert_eq!(v.nnz(), 0);
    }

    #[test]
    fn test_get_out_of_bounds_returns_zero() {
        let v = PackedTernary::zeros(4);
        assert_eq!(v.get(4), 0);
        assert_eq!(v.get(1000), 0);
    }

    // Elements 31 and 32 sit on either side of the first word boundary.
    #[test]
    fn test_word_boundary() {
        let mut v = PackedTernary::zeros(64);
        v.set(31, 1);
        v.set(32, -1);
        assert_eq!(v.get(31), 1);
        assert_eq!(v.get(32), -1);
        assert_eq!(v.get(30), 0);
        assert_eq!(v.get(33), 0);
    }

    #[test]
    fn test_nnz() {
        let mut v = PackedTernary::zeros(10);
        v.set(0, 1);
        v.set(3, -1);
        v.set(7, 1);
        assert_eq!(v.nnz(), 3);
    }

    #[test]
    fn test_sparsity_all_zero() {
        let v = PackedTernary::zeros(100);
        assert!(
            (sparsity(&v) - 1.0).abs() < 1e-6,
            "all-zero vector has sparsity 1.0"
        );
    }

    #[test]
    fn test_sparsity_all_nonzero() {
        let values: Vec<f32> = vec![1.0, -1.0, 1.0, -1.0];
        let packed = encode_ternary(&values, 0.0);
        assert!(
            sparsity(&packed).abs() < 1e-6,
            "all-nonzero vector has sparsity 0.0"
        );
    }

    #[test]
    fn test_sparsity_half() {
        let values: Vec<f32> = vec![1.0, 0.0, -1.0, 0.0];
        let packed = encode_ternary(&values, 0.5);
        assert!((sparsity(&packed) - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_memory_bytes() {
        let v = PackedTernary::zeros(768);
        assert_eq!(v.memory_bytes(), 192);
    }

    #[test]
    fn test_encode_ternary_empty() {
        let packed = encode_ternary(&[], 0.5);
        assert_eq!(packed.dimension, 0);
        assert_eq!(packed.nnz(), 0);
    }

    // The threshold comparisons are strict on both sides.
    #[test]
    fn test_encode_ternary_at_threshold() {
        let packed = encode_ternary(&[0.5, -0.5], 0.5);
        assert_eq!(packed.get(0), 0, "value exactly at +threshold should be 0");
        assert_eq!(packed.get(1), 0, "value exactly at -threshold should be 0");
    }

    #[test]
    fn test_encode_ternary_zero_threshold() {
        let packed = encode_ternary(&[1.0, -1.0, 0.0], 0.0);
        assert_eq!(packed.get(0), 1);
        assert_eq!(packed.get(1), -1);
        assert_eq!(packed.get(2), 0);
    }

    #[test]
    fn test_ternary_dot_all_zeros() {
        let a = PackedTernary::zeros(100);
        let b = PackedTernary::zeros(100);
        assert_eq!(ternary_dot(&a, &b), 0);
    }

    // Two-word vectors: first with identical supports, then with mixed signs in `b`.
    #[test]
    fn test_ternary_dot_mixed_large() {
        let mut a = PackedTernary::zeros(64);
        let mut b = PackedTernary::zeros(64);
        for i in (0..64).step_by(3) {
            a.set(i, 1);
        }
        for i in (0..64).step_by(3) {
            b.set(i, 1);
        }
        let count = (0..64).step_by(3).count() as i32;
        assert_eq!(ternary_dot(&a, &b), count);

        // Flip every other shared position in `b` to -1 and compare with the element-wise sum.
        for i in (0..64).step_by(6) {
            b.set(i, -1);
        }
        let expected: i32 = (0..64)
            .map(|i| i32::from(a.get(i)) * i32::from(b.get(i)))
            .sum();
        assert_eq!(ternary_dot(&a, &b), expected);
    }

    #[test]
    fn test_asymmetric_dot_zero_query() {
        let mut t = PackedTernary::zeros(3);
        t.set(0, 1);
        t.set(1, -1);
        let query = vec![0.0, 0.0, 0.0];
        assert!(asymmetric_dot(&query, &t).abs() < 1e-9);
    }

    #[test]
    fn test_asymmetric_dot_zero_ternary() {
        let t = PackedTernary::zeros(3);
        let query = vec![1.0, 2.0, 3.0];
        assert!(asymmetric_dot(&query, &t).abs() < 1e-9);
    }

    // (-3) * (+1) + (-4) * (-1) = 1
    #[test]
    fn test_asymmetric_dot_negative_query() {
        let mut t = PackedTernary::zeros(2);
        t.set(0, 1);
        t.set(1, -1);
        let query = vec![-3.0, -4.0];
        let result = asymmetric_dot(&query, &t);
        assert!((result - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_hamming_identical() {
        let mut v = PackedTernary::zeros(10);
        v.set(0, 1);
        v.set(3, -1);
        v.set(7, 1);
        assert_eq!(ternary_hamming(&v, &v), 0, "hamming(v, v) should be 0");
    }

    #[test]
    fn test_hamming_all_opposite() {
        let mut a = PackedTernary::zeros(4);
        let mut b = PackedTernary::zeros(4);
        a.set(0, 1);
        a.set(1, -1);
        a.set(2, 1);
        a.set(3, -1);
        b.set(0, -1);
        b.set(1, 1);
        b.set(2, -1);
        b.set(3, 1);
        assert_eq!(ternary_hamming(&a, &b), 4);
    }

    // A position counts only when both vectors are non-zero there.
    #[test]
    fn test_hamming_zeros_ignored() {
        let mut a = PackedTernary::zeros(4);
        let mut b = PackedTernary::zeros(4);
        a.set(0, 1);
        b.set(1, -1);
        assert_eq!(ternary_hamming(&a, &b), 0);
    }

    // The packed kernel must agree with the element-wise product sum.
    #[test]
    fn test_encode_then_dot_consistency() {
        let values_a: Vec<f32> = (0..32).map(|i| (i as f32 - 16.0) / 10.0).collect();
        let values_b: Vec<f32> = (0..32).map(|i| ((i * 3) as f32 - 48.0) / 10.0).collect();
        let a = encode_ternary(&values_a, 0.5);
        let b = encode_ternary(&values_b, 0.5);
        let mut expected = 0i32;
        for i in 0..32 {
            expected += (a.get(i) as i32) * (b.get(i) as i32);
        }
        assert_eq!(ternary_dot(&a, &b), expected);
    }
}
// Property tests over randomly generated ternary vectors.
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    // Dimensions just below, at and just above the 32-element word boundaries.
    const DIMS: &[usize] = &[31, 32, 33, 63, 64, 65, 128];

    // Strategy for a vector of exactly `dim` elements drawn from {-1, 0, 1}.
    fn arb_packed_ternary_of(dim: usize) -> impl Strategy<Value = PackedTernary> {
        prop::collection::vec(prop::sample::select(&[-1i8, 0, 1][..]), dim).prop_map(
            move |vals| {
                let mut pt = PackedTernary::zeros(dim);
                for (i, &v) in vals.iter().enumerate() {
                    pt.set(i, v);
                }
                pt
            },
        )
    }

    // Strategy for a single vector whose dimension is picked from `DIMS`.
    fn arb_packed_ternary() -> impl Strategy<Value = PackedTernary> {
        prop::sample::select(DIMS).prop_flat_map(|dim| arb_packed_ternary_of(dim))
    }

    // Strategy for two independent vectors that share one dimension picked from `DIMS`.
    fn arb_packed_ternary_pair() -> impl Strategy<Value = (PackedTernary, PackedTernary)> {
        prop::sample::select(DIMS)
            .prop_flat_map(|dim| (arb_packed_ternary_of(dim), arb_packed_ternary_of(dim)))
    }

    proptest! {
        // With a query made of the decoded values of `a`, both dot products must agree.
        #[test]
        fn proptest_asymmetric_dot_matches_symmetric_for_ternary_query(
            (a, b) in arb_packed_ternary_pair(),
        ) {
            let query: Vec<f32> = (0..a.dimension).map(|i| a.get(i) as f32).collect();
            let asym = asymmetric_dot(&query, &b);
            let sym = ternary_dot(&a, &b) as f32;
            prop_assert!((asym - sym).abs() < 1e-6,
                "asymmetric({}) != symmetric({})", asym, sym);
        }

        // The non-zero count can never exceed the number of encoded values.
        #[test]
        fn proptest_encode_nnz_le_dimension(
            dim in prop::sample::select(DIMS),
            threshold in 0.0f32..2.0,
        ) {
            let values: Vec<f32> = (0..dim)
                .map(|i| (i as f32 - dim as f32 / 2.0) * 0.1)
                .collect();
            let packed = encode_ternary(&values, threshold);
            prop_assert!(packed.nnz() <= dim,
                "nnz {} > dimension {}", packed.nnz(), dim);
        }

        #[test]
        fn proptest_hamming_symmetry((a, b) in arb_packed_ternary_pair()) {
            prop_assert_eq!(ternary_hamming(&a, &b), ternary_hamming(&b, &a));
        }

        #[test]
        fn proptest_hamming_self_zero(a in arb_packed_ternary()) {
            prop_assert_eq!(ternary_hamming(&a, &a), 0);
        }

        #[test]
        fn proptest_sparsity_range(a in arb_packed_ternary()) {
            let s = sparsity(&a);
            prop_assert!((0.0..=1.0).contains(&s),
                "sparsity {} not in [0, 1]", s);
        }
    }
}