/// A fixed-length bit vector packed into 64-bit words.
///
/// Bit `i` lives in word `i / 64` at bit position `i % 64` (least-significant
/// bit first). Bits at positions `>= dimension` in the final word are padding
/// and are kept at zero by every constructor and mutator. As a result,
/// word-wise popcount operations and the derived `PartialEq` see only the
/// logical bits.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct PackedBinary {
    /// Packed words; always exactly `dimension.div_ceil(64)` long.
    data: Vec<u64>,
    /// Number of logical bits in the vector.
    dimension: usize,
}

impl PackedBinary {
    /// Wraps already-packed words as a vector of `dimension` bits.
    ///
    /// Any bits set beyond `dimension` in the last word are cleared. This keeps
    /// padding from leaking into Hamming/dot/Jaccard results or equality checks.
    ///
    /// # Panics
    ///
    /// Panics if `data.len() != dimension.div_ceil(64)`.
    #[must_use]
    pub fn new(mut data: Vec<u64>, dimension: usize) -> Self {
        let expected_words = dimension.div_ceil(64);
        assert_eq!(
            data.len(),
            expected_words,
            "PackedBinary: data length {} doesn't match dimension {} (expected {} words)",
            data.len(),
            dimension,
            expected_words
        );
        // Clear padding bits past `dimension`. Otherwise `get` would report them
        // as unset while the popcount-based metrics would still count them.
        let tail_bits = dimension % 64;
        if tail_bits != 0 {
            if let Some(last) = data.last_mut() {
                *last &= (1u64 << tail_bits) - 1;
            }
        }
        Self { data, dimension }
    }

    /// Borrows the packed words (least-significant bit first within each word).
    #[must_use]
    pub fn data(&self) -> &[u64] {
        &self.data
    }

    /// Number of logical bits in the vector.
    #[must_use]
    pub fn dimension(&self) -> usize {
        self.dimension
    }

    /// Creates an all-zero vector of `dimension` bits.
    #[must_use]
    pub fn zeros(dimension: usize) -> Self {
        let num_u64s = dimension.div_ceil(64);
        Self {
            data: vec![0; num_u64s],
            dimension,
        }
    }

    /// Sets bit `idx` to `val`.
    ///
    /// An `idx >= dimension` is silently ignored. This also guarantees that
    /// padding bits are never set through this method.
    pub fn set(&mut self, idx: usize, val: bool) {
        if idx >= self.dimension {
            return;
        }
        let word = idx / 64;
        let bit = idx % 64;
        if val {
            self.data[word] |= 1u64 << bit;
        } else {
            self.data[word] &= !(1u64 << bit);
        }
    }

    /// Returns bit `idx`, or `false` when `idx >= dimension`.
    #[must_use]
    pub fn get(&self, idx: usize) -> bool {
        if idx >= self.dimension {
            return false;
        }
        let word = idx / 64;
        let bit = idx % 64;
        ((self.data[word] >> bit) & 1) != 0
    }

    /// Size in bytes of the packed words (`data().len()` words of 8 bytes each).
    ///
    /// Does not include `Vec` capacity slack or the struct itself.
    #[must_use]
    pub fn memory_bytes(&self) -> usize {
        self.data.len() * std::mem::size_of::<u64>()
    }
}
/// Binarizes `values` against `threshold`.
///
/// Bit `i` of the result is set exactly when `values[i] > threshold`. The
/// comparison is strict, so a value equal to the threshold (or NaN) yields 0.
/// The result has `values.len()` bits.
#[must_use]
pub fn encode_binary(values: &[f32], threshold: f32) -> PackedBinary {
    // Each 64-element chunk becomes one packed word; element k of a chunk maps
    // to bit k of that word (least-significant bit first).
    let words: Vec<u64> = values
        .chunks(64)
        .map(|chunk| {
            chunk
                .iter()
                .enumerate()
                .filter(|&(_, &value)| value > threshold)
                .fold(0u64, |word, (bit, _)| word | (1u64 << bit))
        })
        .collect();
    PackedBinary::new(words, values.len())
}
/// Hamming distance: the number of bit positions where `a` and `b` differ.
///
/// Computed as the popcount of `a XOR b`, one 64-bit word at a time.
///
/// # Panics
///
/// Panics if `a` and `b` have different dimensions.
#[inline]
#[must_use]
pub fn binary_hamming(a: &PackedBinary, b: &PackedBinary) -> u32 {
    assert_eq!(a.dimension, b.dimension);
    let mut differing = 0u32;
    for (left, right) in a.data.iter().copied().zip(b.data.iter().copied()) {
        differing += (left ^ right).count_ones();
    }
    differing
}
/// Binary dot product: the number of positions where both `a` and `b` are set.
///
/// Computed as the popcount of `a AND b`, one 64-bit word at a time.
///
/// # Panics
///
/// Panics if `a` and `b` have different dimensions.
#[inline]
#[must_use]
pub fn binary_dot(a: &PackedBinary, b: &PackedBinary) -> u32 {
    assert_eq!(a.dimension, b.dimension);
    a.data
        .iter()
        .zip(&b.data)
        .fold(0u32, |acc, (&left, &right)| acc + (left & right).count_ones())
}
/// Jaccard similarity `|a AND b| / |a OR b|` of two bit vectors.
///
/// When both vectors are all-zero the union is empty, and the similarity is
/// defined as `1.0`.
///
/// # Panics
///
/// Panics if `a` and `b` have different dimensions (checked by `binary_dot`).
#[must_use]
pub fn binary_jaccard(a: &PackedBinary, b: &PackedBinary) -> f32 {
    let shared = binary_dot(a, b);
    let combined: u32 = a
        .data
        .iter()
        .zip(&b.data)
        .map(|(left, right)| (left | right).count_ones())
        .sum();
    match combined {
        0 => 1.0,
        total => shared as f32 / total as f32,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_binary_ops() {
        let mut a = PackedBinary::zeros(4);
        let mut b = PackedBinary::zeros(4);
        a.set(0, true);
        a.set(1, true);
        b.set(1, true);
        b.set(2, true);
        assert_eq!(binary_hamming(&a, &b), 2);
        assert_eq!(binary_dot(&a, &b), 1);
        assert!((binary_jaccard(&a, &b) - 1.0 / 3.0).abs() < 1e-6);
    }

    #[test]
    fn test_zeros() {
        let v = PackedBinary::zeros(128);
        assert_eq!(v.dimension, 128);
        assert_eq!(v.data.len(), 2);
        for i in 0..128 {
            assert!(!v.get(i), "bit {i} should be 0");
        }
    }

    #[test]
    fn test_zero_dimension() {
        let v = PackedBinary::zeros(0);
        assert!(v.data.is_empty());
        assert_eq!(v.memory_bytes(), 0);
        assert!(!v.get(0));
    }

    #[test]
    fn test_new() {
        let data = vec![0xFF_u64];
        let v = PackedBinary::new(data, 8);
        for i in 0..8 {
            assert!(v.get(i), "bit {i} should be 1");
        }
    }

    #[test]
    #[should_panic(expected = "doesn't match dimension")]
    fn test_new_length_mismatch_panics() {
        let _ = PackedBinary::new(vec![0, 0], 4);
    }

    #[test]
    fn test_memory_bytes() {
        let v = PackedBinary::zeros(256);
        assert_eq!(v.memory_bytes(), 32);
    }

    #[test]
    fn test_set_and_clear() {
        let mut v = PackedBinary::zeros(64);
        v.set(0, true);
        assert!(v.get(0));
        v.set(0, false);
        assert!(!v.get(0));
    }

    #[test]
    fn test_set_out_of_bounds_is_noop() {
        let mut v = PackedBinary::zeros(4);
        v.set(100, true);
        assert!(!v.get(100));
        assert_eq!(v, PackedBinary::zeros(4));
    }

    #[test]
    fn test_get_out_of_bounds() {
        let v = PackedBinary::zeros(4);
        assert!(!v.get(4));
        assert!(!v.get(1000));
    }

    #[test]
    fn test_set_last_bit_in_word() {
        let mut v = PackedBinary::zeros(64);
        v.set(63, true);
        assert!(v.get(63));
        assert!(!v.get(62));
    }

    #[test]
    fn test_set_across_word_boundary() {
        let mut v = PackedBinary::zeros(65);
        v.set(63, true);
        v.set(64, true);
        assert_eq!(v.data, vec![1u64 << 63, 1]);
    }

    #[test]
    fn test_multi_word_hamming() {
        let mut a = PackedBinary::zeros(128);
        let mut b = PackedBinary::zeros(128);
        a.set(0, true);
        b.set(0, true);
        a.set(64, true);
        b.set(65, true);
        assert_eq!(binary_hamming(&a, &b), 2);
    }

    #[test]
    #[should_panic]
    fn test_hamming_dimension_mismatch_panics() {
        let a = PackedBinary::zeros(4);
        let b = PackedBinary::zeros(8);
        let _ = binary_hamming(&a, &b);
    }

    #[test]
    fn test_multi_word_dot() {
        let mut a = PackedBinary::zeros(128);
        let mut b = PackedBinary::zeros(128);
        a.set(0, true);
        a.set(64, true);
        a.set(65, true);
        b.set(0, true);
        b.set(64, true);
        b.set(100, true);
        assert_eq!(binary_dot(&a, &b), 2);
    }

    #[test]
    fn test_multi_word_jaccard() {
        let mut a = PackedBinary::zeros(128);
        let mut b = PackedBinary::zeros(128);
        a.set(0, true);
        a.set(64, true);
        a.set(65, true);
        b.set(0, true);
        b.set(64, true);
        b.set(100, true);
        let j = binary_jaccard(&a, &b);
        assert!((j - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_encode_binary_all_above() {
        let v = [1.0, 2.0, 3.0, 4.0];
        let packed = encode_binary(&v, 0.0);
        for i in 0..4 {
            assert!(packed.get(i), "all values > 0, bit {i} should be set");
        }
    }

    #[test]
    fn test_encode_binary_all_below() {
        let v = [-1.0, -2.0, -3.0, -4.0];
        let packed = encode_binary(&v, 0.0);
        for i in 0..4 {
            assert!(!packed.get(i), "all values <= 0, bit {i} should be clear");
        }
    }

    #[test]
    fn test_encode_binary_at_threshold() {
        let v = [0.0_f32];
        let packed = encode_binary(&v, 0.0);
        assert!(!packed.get(0), "value exactly at threshold should be 0");
    }

    #[test]
    fn test_encode_binary_empty() {
        let v: [f32; 0] = [];
        let packed = encode_binary(&v, 0.0);
        assert_eq!(packed.dimension, 0);
    }

    #[test]
    fn test_encode_binary_large() {
        let v: Vec<f32> = (0..768)
            .map(|i| if i % 2 == 0 { 1.0 } else { -1.0 })
            .collect();
        let packed = encode_binary(&v, 0.0);
        assert_eq!(packed.dimension, 768);
        assert_eq!(packed.data.len(), 12);
        for i in 0..768 {
            if i % 2 == 0 {
                assert!(packed.get(i), "even index {i} should be 1");
            } else {
                assert!(!packed.get(i), "odd index {i} should be 0");
            }
        }
    }

    #[test]
    fn test_hamming_identical() {
        let v = encode_binary(&[1.0, -1.0, 1.0, -1.0], 0.0);
        assert_eq!(binary_hamming(&v, &v), 0);
    }

    #[test]
    fn test_hamming_complement() {
        let a = encode_binary(&[1.0, 1.0, 1.0, 1.0], 0.0);
        let b = encode_binary(&[-1.0, -1.0, -1.0, -1.0], 0.0);
        assert_eq!(binary_hamming(&a, &b), 4);
    }

    #[test]
    fn test_dot_self() {
        let v = encode_binary(&[1.0, -1.0, 1.0, -1.0, 1.0], 0.0);
        assert_eq!(binary_dot(&v, &v), 3);
    }

    #[test]
    fn test_jaccard_identical() {
        let v = encode_binary(&[1.0, -1.0, 1.0], 0.0);
        let j = binary_jaccard(&v, &v);
        assert!((j - 1.0).abs() < 1e-6, "jaccard(v, v) should be 1.0");
    }

    #[test]
    fn test_jaccard_disjoint() {
        let a = encode_binary(&[1.0, -1.0], 0.0);
        let b = encode_binary(&[-1.0, 1.0], 0.0);
        let j = binary_jaccard(&a, &b);
        assert!(j.abs() < 1e-6);
    }

    #[test]
    fn test_jaccard_both_empty() {
        let a = PackedBinary::zeros(4);
        let b = PackedBinary::zeros(4);
        assert!((binary_jaccard(&a, &b) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_encode_binary_nonzero_threshold() {
        let v = [0.1, 0.5, 0.9, 1.5];
        let packed = encode_binary(&v, 0.5);
        assert!(!packed.get(0));
        assert!(!packed.get(1));
        assert!(packed.get(2));
        assert!(packed.get(3));
    }
}
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    /// Dimensions straddling the 32- and 64-bit word boundaries, where
    /// padding-related mistakes would show up.
    const DIMS: &[usize] = &[31, 32, 33, 63, 64, 65, 128];

    /// Builds a `PackedBinary` whose bit `i` equals `bits[i]`.
    fn packed_from_bits(bits: &[bool]) -> PackedBinary {
        let mut pb = PackedBinary::zeros(bits.len());
        for (i, &bit) in bits.iter().enumerate() {
            pb.set(i, bit);
        }
        pb
    }

    /// Strategy for a random vector of exactly `dim` bits.
    fn arb_packed_binary_of_dim(dim: usize) -> impl Strategy<Value = PackedBinary> {
        prop::collection::vec(any::<bool>(), dim).prop_map(|bits| packed_from_bits(&bits))
    }

    /// Strategy for a random vector whose dimension is drawn from `DIMS`.
    fn arb_packed_binary() -> impl Strategy<Value = PackedBinary> {
        prop::sample::select(DIMS).prop_flat_map(arb_packed_binary_of_dim)
    }

    /// Strategy for two random vectors sharing one dimension drawn from `DIMS`.
    fn arb_packed_binary_pair() -> impl Strategy<Value = (PackedBinary, PackedBinary)> {
        prop::sample::select(DIMS)
            .prop_flat_map(|dim| (arb_packed_binary_of_dim(dim), arb_packed_binary_of_dim(dim)))
    }

    proptest! {
        #[test]
        fn proptest_hamming_symmetry((a, b) in arb_packed_binary_pair()) {
            prop_assert_eq!(binary_hamming(&a, &b), binary_hamming(&b, &a));
        }

        #[test]
        fn proptest_hamming_self_zero(a in arb_packed_binary()) {
            prop_assert_eq!(binary_hamming(&a, &a), 0);
        }

        #[test]
        fn proptest_hamming_range((a, b) in arb_packed_binary_pair()) {
            let h = binary_hamming(&a, &b);
            prop_assert!(h as usize <= a.dimension,
                "hamming {} > dimension {}", h, a.dimension);
        }

        #[test]
        fn proptest_hamming_dot_identity((a, b) in arb_packed_binary_pair()) {
            // |a XOR b| + 2|a AND b| == |a| + |b|
            let pop_a = binary_dot(&a, &a);
            let pop_b = binary_dot(&b, &b);
            prop_assert_eq!(binary_hamming(&a, &b) + 2 * binary_dot(&a, &b), pop_a + pop_b);
        }

        #[test]
        fn proptest_jaccard_range((a, b) in arb_packed_binary_pair()) {
            let j = binary_jaccard(&a, &b);
            prop_assert!((0.0..=1.0).contains(&j),
                "jaccard {} not in [0, 1]", j);
        }

        #[test]
        fn proptest_jaccard_matches_counts((a, b) in arb_packed_binary_pair()) {
            // |a OR b| == |a AND b| + |a XOR b|
            let intersection = binary_dot(&a, &b);
            let union = intersection + binary_hamming(&a, &b);
            let expected = if union == 0 {
                1.0
            } else {
                intersection as f32 / union as f32
            };
            prop_assert!((binary_jaccard(&a, &b) - expected).abs() < 1e-6);
        }

        #[test]
        fn proptest_jaccard_self(a in arb_packed_binary()) {
            // Self-similarity is 1.0 for every vector, including the all-zero
            // one, because binary_jaccard defines an empty union as 1.0.
            let j = binary_jaccard(&a, &a);
            prop_assert!((j - 1.0).abs() < 1e-6,
                "jaccard(a, a) = {} but expected 1.0", j);
        }

        #[test]
        fn proptest_dot_commutativity((a, b) in arb_packed_binary_pair()) {
            prop_assert_eq!(binary_dot(&a, &b), binary_dot(&b, &a));
        }

        #[test]
        fn proptest_encode_deterministic(
            dim in prop::sample::select(DIMS),
            threshold in -10.0f32..10.0,
        ) {
            let values: Vec<f32> = (0..dim).map(|i| (i as f32 - dim as f32 / 2.0) * 0.1).collect();
            let a = encode_binary(&values, threshold);
            let b = encode_binary(&values, threshold);
            prop_assert_eq!(a, b);
        }

        #[test]
        fn proptest_encode_matches_threshold(
            values in prop::collection::vec(-10.0f32..10.0, 0..200),
            threshold in -10.0f32..10.0,
        ) {
            let packed = encode_binary(&values, threshold);
            prop_assert_eq!(packed.dimension(), values.len());
            for (i, &v) in values.iter().enumerate() {
                prop_assert_eq!(packed.get(i), v > threshold);
            }
        }
    }
}