use rand::{Rng, RngExt};
use crate::simd;
/// Bit-packed stream: `length` logical bits stored in 64-bit words.
///
/// `pack` and `unpack` in this file keep bit `i` at bit position `i % 64`
/// of `data[i / 64]`.
#[derive(Clone, Debug)]
pub struct BitStreamTensor {
/// Packed 64-bit words holding the bits.
pub data: Vec<u64>,
/// Number of logical bits represented by `data`.
pub length: usize,
}
impl BitStreamTensor {
    /// Wraps already-packed words as a bitstream holding `length` logical bits.
    ///
    /// The number of words is not checked against `length`.
    ///
    /// # Panics
    /// Panics if `length` is zero.
    pub fn from_words(data: Vec<u64>, length: usize) -> Self {
        assert!(length > 0, "bitstream length must be > 0");
        Self { data, length }
    }

    /// XORs `other` into `self`, word by word.
    ///
    /// # Panics
    /// Panics if the two streams report different `length`s.
    pub fn xor_inplace(&mut self, other: &BitStreamTensor) {
        assert_eq!(
            self.length, other.length,
            "Bitstream lengths must match for XOR."
        );
        for (dst, src) in self.data.iter_mut().zip(&other.data) {
            *dst ^= *src;
        }
    }

    /// Returns a new stream holding the word-wise XOR of `self` and `other`.
    ///
    /// # Panics
    /// Panics if the two streams report different `length`s.
    pub fn xor(&self, other: &BitStreamTensor) -> BitStreamTensor {
        assert_eq!(
            self.length, other.length,
            "Bitstream lengths must match for XOR."
        );
        let data = self
            .data
            .iter()
            .zip(&other.data)
            .map(|(&lhs, &rhs)| lhs ^ rhs)
            .collect();
        BitStreamTensor {
            data,
            length: self.length,
        }
    }

    /// Rotates the logical bits right by `shift` positions, so the bit at
    /// index `i` moves to index `(i + shift) % length`.
    ///
    /// The stream is unpacked, rotated and re-packed, so afterwards `data`
    /// holds exactly `length.div_ceil(64)` words. A zero-length stream or a
    /// shift that is a multiple of `length` leaves `self` untouched.
    pub fn rotate_right(&mut self, shift: usize) {
        if self.length == 0 || shift.is_multiple_of(self.length) {
            return;
        }
        let mut bits = unpack(self);
        bits.rotate_right(shift % self.length);
        *self = pack(&bits);
    }

    /// Returns the count produced by `crate::simd::fused_xor_popcount_dispatch`
    /// for the two word slices, divided by `length`.
    ///
    /// NOTE(review): judging by its name the helper counts the set bits of
    /// `self ^ other`, which would make this the normalized Hamming distance —
    /// confirm against the helper.
    ///
    /// # Panics
    /// Panics if the two streams report different `length`s.
    pub fn hamming_distance(&self, other: &BitStreamTensor) -> f32 {
        assert_eq!(
            self.length, other.length,
            "Bitstream lengths must match for Hamming distance."
        );
        let xor_count: u64 = crate::simd::fused_xor_popcount_dispatch(&self.data, &other.data);
        xor_count as f32 / self.length as f32
    }

    /// Combines several streams by per-bit majority vote.
    ///
    /// A result bit is set when strictly more than half of the inputs have it
    /// set, so a tie among an even number of inputs yields `0`. A single input
    /// is returned as a clone. The result has the `length` and word count of
    /// the first input.
    ///
    /// # Panics
    /// Panics if `vectors` is empty, if any input's `length` differs from the
    /// first input's, or if any input holds fewer packed words than the first.
    pub fn bundle(vectors: &[&BitStreamTensor]) -> BitStreamTensor {
        assert!(!vectors.is_empty(), "Cannot bundle zero vectors.");
        let first = vectors[0];
        let length = first.length;
        let words = first.data.len();
        // Validate up front so a malformed input fails with a clear message
        // instead of an index-out-of-bounds panic or a silently truncated vote.
        for v in &vectors[1..] {
            assert_eq!(
                v.length, length,
                "Bitstream lengths must match for bundle."
            );
            assert!(
                v.data.len() >= words,
                "Every bundled bitstream must hold at least as many packed words as the first."
            );
        }
        if vectors.len() == 1 {
            return first.clone();
        }
        let data: Vec<u64> = if let [a, b, c] = vectors {
            // Three-input majority has a closed bitwise form.
            a.data
                .iter()
                .zip(&b.data)
                .zip(&c.data)
                .map(|((&x, &y), &z)| (x & y) | (y & z) | (x & z))
                .collect()
        } else {
            let threshold = vectors.len() / 2;
            (0..words)
                .map(|i| {
                    (0..64).fold(0_u64, |word, bit| {
                        let votes = vectors
                            .iter()
                            .filter(|v| (v.data[i] >> bit) & 1 == 1)
                            .count();
                        if votes > threshold {
                            word | (1_u64 << bit)
                        } else {
                            word
                        }
                    })
                })
                .collect()
        };
        BitStreamTensor { data, length }
    }
}
/// Packs a one-flag-per-byte slice into 64-bit words.
///
/// Any non-zero byte counts as a set bit. Flag `i` lands at bit position
/// `i % 64` of word `i / 64`; the returned `length` is `bits.len()`.
pub fn pack(bits: &[u8]) -> BitStreamTensor {
    let data = bits
        .chunks(64)
        .map(|lane| {
            lane.iter().enumerate().fold(0_u64, |word, (pos, &flag)| {
                if flag != 0 {
                    word | (1_u64 << pos)
                } else {
                    word
                }
            })
        })
        .collect();
    BitStreamTensor {
        data,
        length: bits.len(),
    }
}
/// Packs a one-flag-per-byte slice into 64-bit words, assembling each word
/// from byte-sized groups of eight flags.
///
/// Produces the same layout as `pack`: any non-zero byte is a set bit, and
/// flag `i` lands at bit position `i % 64` of word `i / 64`.
pub fn pack_fast(bits: &[u8]) -> BitStreamTensor {
    // Collapses up to eight flags into one byte, first flag in bit 0.
    fn fold_flags(group: &[u8]) -> u8 {
        group
            .iter()
            .enumerate()
            .fold(0_u8, |byte, (pos, &flag)| byte | (u8::from(flag != 0) << pos))
    }

    let data = bits
        .chunks(64)
        .map(|lane| {
            lane.chunks(8)
                .map(fold_flags)
                .enumerate()
                .fold(0_u64, |word, (byte_idx, byte)| {
                    word | (u64::from(byte) << (byte_idx * 8))
                })
        })
        .collect();
    BitStreamTensor {
        data,
        length: bits.len(),
    }
}
/// Expands a packed stream into one byte (`0` or `1`) per logical bit.
///
/// Reads bit `i % 64` of `tensor.data[i / 64]` for every `i < tensor.length`.
///
/// # Panics
/// Panics if `tensor.data` holds fewer words than `tensor.length` requires.
pub fn unpack(tensor: &BitStreamTensor) -> Vec<u8> {
    (0..tensor.length)
        .map(|idx| {
            let word = tensor.data[idx / 64];
            u8::from((word >> (idx % 64)) & 1 == 1)
        })
        .collect()
}
/// Returns the word-wise AND of two packed streams.
///
/// # Panics
/// Panics if the streams differ in `length` or in the number of packed words.
pub fn bitwise_and(a: &BitStreamTensor, b: &BitStreamTensor) -> BitStreamTensor {
    assert_eq!(
        a.length, b.length,
        "Bitstream lengths must match for bitwise AND."
    );
    assert_eq!(
        a.data.len(),
        b.data.len(),
        "Packed bitstream shapes must match for bitwise AND."
    );
    // Start from a copy of `a` and mask it with `b` in place.
    let mut data = a.data.clone();
    for (dst, mask) in data.iter_mut().zip(&b.data) {
        *dst &= *mask;
    }
    BitStreamTensor {
        data,
        length: a.length,
    }
}
/// Counts the set bits of `x` with the branch-free SWAR reduction:
/// 2-bit sums, then 4-bit sums, then per-byte sums, finally added together
/// by a multiply that accumulates every byte into the top byte.
pub fn swar_popcount_word(x: u64) -> u64 {
    const EVERY_OTHER_BIT: u64 = 0x5555_5555_5555_5555;
    const EVERY_OTHER_PAIR: u64 = 0x3333_3333_3333_3333;
    const LOW_NIBBLES: u64 = 0x0f0f_0f0f_0f0f_0f0f;
    const ONE_PER_BYTE: u64 = 0x0101_0101_0101_0101;

    let pair_sums = x.wrapping_sub((x >> 1) & EVERY_OTHER_BIT);
    let nibble_sums = (pair_sums & EVERY_OTHER_PAIR) + ((pair_sums >> 2) & EVERY_OTHER_PAIR);
    let byte_sums = (nibble_sums + (nibble_sums >> 4)) & LOW_NIBBLES;
    byte_sums.wrapping_mul(ONE_PER_BYTE) >> 56
}
pub fn popcount_words_portable(data: &[u64]) -> u64 {
data.iter().copied().map(swar_popcount_word).sum()
}
/// Counts the set bits across all packed words of `tensor`.
///
/// Every word in `tensor.data` is counted in full; `tensor.length` is not
/// consulted.
pub fn popcount(tensor: &BitStreamTensor) -> u64 {
    let words: &[u64] = tensor.data.as_slice();
    popcount_words_portable(words)
}
/// Draws `length` flags, one byte each, where a flag is `1` when a fresh
/// `f64` sample from `rng` is below `prob` (clamped to `[0, 1]`).
///
/// Exactly one `f64` is drawn from `rng` per flag, in order.
pub fn bernoulli_stream<R: Rng + ?Sized>(prob: f64, length: usize, rng: &mut R) -> Vec<u8> {
    let p = prob.clamp(0.0, 1.0);
    (0..length)
        .map(|_| u8::from(rng.random::<f64>() < p))
        .collect()
}
/// Draws `length` Bernoulli bits directly into packed 64-bit words.
///
/// One `f64` is drawn from `rng` per bit, in bit order; bit `i` is set when
/// its sample is below `prob` (clamped to `[0, 1]`) and is stored at bit
/// position `i % 64` of word `i / 64`. Unused high bits of the final word
/// stay zero.
pub fn bernoulli_packed<R: Rng + ?Sized>(prob: f64, length: usize, rng: &mut R) -> Vec<u64> {
    let p = prob.clamp(0.0, 1.0);
    let mut data = Vec::with_capacity(length.div_ceil(64));
    let mut remaining = length;
    while remaining > 0 {
        let bits_here = remaining.min(64);
        let mut word = 0_u64;
        for bit in 0..bits_here {
            if rng.random::<f64>() < p {
                word |= 1_u64 << bit;
            }
        }
        data.push(word);
        remaining -= bits_here;
    }
    data
}
/// Draws `length` Bernoulli bits into packed words using one random byte per
/// bit instead of one `f64`.
///
/// `prob <= 0` yields all zeros and `prob >= 1` yields all ones (with the
/// unused high bits of the last word cleared); neither touches `rng`.
/// Otherwise the probability is quantized to the 8-bit threshold
/// `(prob * 256.0) as u8`, and a bit is set when its random byte is below that
/// threshold. Bytes are requested from `rng` one word's worth at a time.
pub fn bernoulli_packed_fast<R: Rng + ?Sized>(prob: f64, length: usize, rng: &mut R) -> Vec<u64> {
    let words = length.div_ceil(64);
    if prob <= 0.0 {
        return vec![0_u64; words];
    }
    if prob >= 1.0 {
        let mut saturated = vec![u64::MAX; words];
        let tail_bits = length % 64;
        if let (Some(last), true) = (saturated.last_mut(), tail_bits > 0) {
            *last = (1_u64 << tail_bits) - 1;
        }
        return saturated;
    }
    let threshold = (prob.clamp(0.0, 1.0) * 256.0) as u8;
    let mut scratch = [0_u8; 64];
    let mut data = Vec::with_capacity(words);
    let mut remaining = length;
    while remaining > 0 {
        let bits_here = remaining.min(64);
        rng.fill(&mut scratch[..bits_here]);
        let word = scratch[..bits_here]
            .iter()
            .enumerate()
            .filter(|&(_, &rb)| rb < threshold)
            .fold(0_u64, |acc, (bit, _)| acc | (1_u64 << bit));
        data.push(word);
        remaining -= bits_here;
    }
    data
}
/// Draws `length` Bernoulli bits into packed words, comparing random bytes
/// against an 8-bit threshold in batches.
///
/// `prob <= 0` yields all zeros and `prob >= 1` yields all ones (with the
/// unused high bits of the last word cleared); neither touches `rng`.
/// Otherwise the threshold is `(prob * 256.0) as u8` and `rng` is consumed in
/// this order:
/// 1. 1024 bytes per group of 16 full words, handed to
///    `crate::simd::bernoulli_compare_batch_1024`;
/// 2. 64 bytes per leftover full word, handed to
///    `simd_bernoulli_compare_exposed`;
/// 3. `length % 64` bytes for a partial last word, compared one by one.
pub fn bernoulli_packed_simd<R: Rng + ?Sized>(prob: f64, length: usize, rng: &mut R) -> Vec<u64> {
    let words = length.div_ceil(64);
    let tail_bits = length % 64;
    if prob <= 0.0 {
        return vec![0_u64; words];
    }
    if prob >= 1.0 {
        let mut saturated = vec![u64::MAX; words];
        match saturated.last_mut() {
            Some(last) if tail_bits > 0 => *last = (1_u64 << tail_bits) - 1,
            _ => {}
        }
        return saturated;
    }
    let threshold = (prob.clamp(0.0, 1.0) * 256.0) as u8;
    let mut data = vec![0_u64; words];
    let mut buf = [0_u8; 1024];
    {
        // `body` holds the full 64-bit words; `tail` is the partial last word,
        // if there is one.
        let (body, tail) = data.split_at_mut(length / 64);
        let mut blocks = body.chunks_exact_mut(16);
        for block in blocks.by_ref() {
            rng.fill(&mut buf);
            crate::simd::bernoulli_compare_batch_1024(&buf, threshold, block);
        }
        for word in blocks.into_remainder() {
            let mut lane_bytes = [0_u8; 64];
            rng.fill(&mut lane_bytes);
            *word = simd_bernoulli_compare_exposed(&lane_bytes, threshold);
        }
        if let Some(last) = tail.first_mut() {
            rng.fill(&mut buf[..tail_bits]);
            *last = buf[..tail_bits]
                .iter()
                .enumerate()
                .filter(|&(_, &rb)| rb < threshold)
                .fold(0_u64, |acc, (bit, _)| acc | (1_u64 << bit));
        }
    }
    data
}
/// Generates a Bernoulli bitstream on the fly, ANDs it with `weight_words`
/// and returns the number of surviving set bits, without materializing the
/// stream.
///
/// `prob <= 0` returns `0`; `prob >= 1` returns the popcount of the first
/// `length` bits of `weight_words`; neither touches `rng`. Otherwise the
/// threshold is `(prob * 256.0) as u8` and random bytes are drawn in the same
/// three stages as `bernoulli_packed_simd` (16-word batches, leftover full
/// words, partial last word). The partial last word is skipped, without
/// drawing from `rng`, when `weight_words` has no word at that index.
///
/// # Panics
/// For `0 < prob < 1`, panics if `weight_words` holds fewer than
/// `length / 64` words.
pub fn encode_and_popcount<R: Rng + ?Sized>(
    weight_words: &[u64],
    prob: f64,
    length: usize,
    rng: &mut R,
) -> u64 {
    if prob <= 0.0 {
        return 0;
    }
    let full_words = length / 64;
    let tail_bits = length % 64;
    if prob >= 1.0 {
        // Every encoded bit is one, so the result is just the weight popcount
        // restricted to the first `length` bits.
        let body: u64 = weight_words
            .iter()
            .take(full_words)
            .map(|w| w.count_ones() as u64)
            .sum();
        let tail = match weight_words.get(full_words) {
            Some(&w) if tail_bits > 0 => (w & ((1_u64 << tail_bits) - 1)).count_ones() as u64,
            _ => 0,
        };
        return body + tail;
    }
    let threshold = (prob.clamp(0.0, 1.0) * 256.0) as u8;
    let mut buf = [0_u8; 1024];
    let mut encoded_batch = [0_u64; 16];
    let mut total = 0_u64;
    let mut blocks = weight_words[..full_words].chunks_exact(16);
    for block in blocks.by_ref() {
        rng.fill(&mut buf);
        crate::simd::bernoulli_compare_batch_1024(&buf, threshold, &mut encoded_batch);
        total += block
            .iter()
            .zip(encoded_batch.iter())
            .map(|(&w_word, &encoded)| (encoded & w_word).count_ones() as u64)
            .sum::<u64>();
    }
    for &w_word in blocks.remainder() {
        let mut lane_bytes = [0_u8; 64];
        rng.fill(&mut lane_bytes);
        let encoded = simd_bernoulli_compare_exposed(&lane_bytes, threshold);
        total += (encoded & w_word).count_ones() as u64;
    }
    if tail_bits > 0 {
        if let Some(&w_word) = weight_words.get(full_words) {
            rng.fill(&mut buf[..tail_bits]);
            let encoded = buf[..tail_bits]
                .iter()
                .enumerate()
                .filter(|&(_, &rb)| rb < threshold)
                .fold(0_u64, |acc, (bit, _)| acc | (1_u64 << bit));
            total += (encoded & w_word).count_ones() as u64;
        }
    }
    total
}
/// Compares the first 64 bytes of `buf` against `threshold` and returns a
/// 64-bit mask whose bit `i` is set when `buf[i] < threshold` (as computed by
/// the scalar fallback; the SIMD kernels are expected to agree with it).
///
/// On x86-64 the AVX-512BW kernel is used when available, otherwise two
/// AVX2 calls over the low and high 32 bytes, otherwise a scalar loop.
///
/// Callers should pass at least 64 bytes; this is checked with a
/// `debug_assert!`. In release builds a shorter buffer never reaches the SIMD
/// kernels: it is handled by the scalar loop, which only looks at the bytes
/// that are present.
#[inline]
pub fn simd_bernoulli_compare_exposed(buf: &[u8], threshold: u8) -> u64 {
    debug_assert!(buf.len() >= 64, "buffer must contain at least 64 bytes");
    #[cfg(target_arch = "x86_64")]
    {
        // The kernels are `unsafe`, so the length precondition must hold in
        // release builds too, not only where `debug_assert!` is active.
        if buf.len() >= 64 {
            if is_x86_feature_detected!("avx512bw") {
                // SAFETY: AVX-512BW support was detected at runtime just above
                // and `buf` holds at least 64 bytes.
                // NOTE(review): assumes the kernel reads at most 64 bytes —
                // confirm against its `# Safety` contract.
                return unsafe { simd::avx512::bernoulli_compare_avx512(buf, threshold) };
            }
            if is_x86_feature_detected!("avx2") {
                // SAFETY: AVX2 support was detected at runtime just above and
                // each call receives an exactly 32-byte sub-slice.
                // NOTE(review): assumes the kernel reads at most 32 bytes —
                // confirm against its `# Safety` contract.
                let lo = unsafe { simd::avx2::bernoulli_compare_avx2(&buf[0..32], threshold) };
                // SAFETY: as above, for the upper 32 bytes.
                let hi = unsafe { simd::avx2::bernoulli_compare_avx2(&buf[32..64], threshold) };
                return (lo as u64) | ((hi as u64) << 32);
            }
        }
    }
    // Portable fallback; `take(64)` also makes it safe for short buffers.
    let mut mask = 0_u64;
    for (bit, &rb) in buf.iter().take(64).enumerate() {
        if rb < threshold {
            mask |= 1_u64 << bit;
        }
    }
    mask
}
/// Encodes up to `rows * cols` probabilities from `values` as packed
/// Bernoulli bitstreams, one `Vec<u64>` per probability, in input order.
///
/// Each stream is produced by `bernoulli_packed_simd(value, length, rng)` and
/// then resized to exactly `words` entries (zero-padded, or truncated if it
/// was longer). If `values` holds fewer than `rows * cols` items, only those
/// are encoded.
pub fn encode_matrix_prob_to_packed<R: Rng + ?Sized>(
    values: &[f64],
    rows: usize,
    cols: usize,
    length: usize,
    words: usize,
    rng: &mut R,
) -> Vec<Vec<u64>> {
    let cell_count = rows * cols;
    let mut packed = Vec::with_capacity(cell_count);
    packed.extend(values.iter().take(cell_count).map(|&prob| {
        let mut stream = bernoulli_packed_simd(prob, length, rng);
        stream.resize(words, 0);
        stream
    }));
    packed
}
#[cfg(test)]
mod tests {
    use super::{
        bernoulli_packed, bernoulli_packed_fast, bernoulli_packed_simd, bernoulli_stream,
        bitwise_and, encode_and_popcount, pack, pack_fast, popcount, unpack,
    };

    // Share of set bits in `words`, relative to `length` logical bits.
    fn ones_fraction(words: &[u64], length: usize) -> f64 {
        let count: u64 = words.iter().map(|w| w.count_ones() as u64).sum();
        count as f64 / length as f64
    }

    /// Packing then unpacking returns the original flags.
    #[test]
    fn pack_unpack_roundtrip() {
        let bits = vec![1, 0, 1, 1, 0, 1, 0, 0, 1];
        let restored = unpack(&pack(&bits));
        assert_eq!(bits, restored);
    }

    /// `pack_fast` agrees with `pack` around byte and word boundaries.
    #[test]
    fn pack_fast_matches_pack() {
        for length in [0_usize, 1, 7, 8, 9, 63, 64, 65, 127, 128, 256, 1025] {
            let bits: Vec<u8> = (0..length).map(|i| ((i * 7 + 3) % 2) as u8).collect();
            let (slow, fast) = (pack(&bits), pack_fast(&bits));
            assert_eq!(fast.length, slow.length);
            assert_eq!(fast.data, slow.data, "Mismatch at length={length}");
        }
    }

    /// `pack_fast` output unpacks back to its input.
    #[test]
    fn pack_fast_roundtrip() {
        let bits: Vec<u8> = (0..2048).map(|i| ((i * 5 + 1) % 2) as u8).collect();
        let restored = unpack(&pack_fast(&bits));
        assert_eq!(bits, restored);
    }

    /// AND keeps only the common bits, and `popcount` counts them.
    #[test]
    fn and_and_popcount() {
        let lhs = pack(&[1, 0, 1, 1, 0, 0, 1, 1]);
        let rhs = pack(&[1, 1, 1, 0, 0, 1, 1, 0]);
        let both = bitwise_and(&lhs, &rhs);
        assert_eq!(unpack(&both), vec![1, 0, 1, 0, 0, 0, 1, 0]);
        assert_eq!(popcount(&both), 3);
    }

    /// Direct packed generation matches generating a byte stream and packing it.
    #[test]
    fn bernoulli_packed_matches_stream_then_pack() {
        use rand::SeedableRng;
        use rand_chacha::ChaCha8Rng;
        let prob = 0.35;
        let length = 200;
        let packed_via_stream = {
            let mut rng = ChaCha8Rng::seed_from_u64(999);
            let stream = bernoulli_stream(prob, length, &mut rng);
            pack(&stream).data
        };
        let packed_direct = {
            let mut rng = ChaCha8Rng::seed_from_u64(999);
            bernoulli_packed(prob, length, &mut rng)
        };
        assert_eq!(
            packed_via_stream, packed_direct,
            "bernoulli_packed must produce bit-identical output"
        );
    }

    /// The byte-threshold generator lands near the requested density.
    #[test]
    fn bernoulli_packed_fast_statistics() {
        use rand::SeedableRng;
        use rand_chacha::ChaCha8Rng;
        let prob = 0.35;
        let length = 10_000;
        let mut rng = ChaCha8Rng::seed_from_u64(42);
        let packed = bernoulli_packed_fast(prob, length, &mut rng);
        let measured = ones_fraction(&packed, length);
        assert!(
            (measured - prob).abs() < 0.03,
            "Expected ~{prob}, got {measured}"
        );
    }

    /// Identical seeds give identical output from the byte-threshold generator.
    #[test]
    fn bernoulli_packed_fast_deterministic() {
        use rand::SeedableRng;
        use rand_chacha::ChaCha8Rng;
        let run = || {
            let mut rng = ChaCha8Rng::seed_from_u64(99);
            bernoulli_packed_fast(0.5, 512, &mut rng)
        };
        let a = run();
        let b = run();
        assert_eq!(a, b, "Same seed must produce identical output");
    }

    /// The batched generator lands near the requested density.
    #[test]
    fn bernoulli_packed_simd_statistics() {
        use rand::SeedableRng;
        use rand_chacha::ChaCha8Rng;
        let prob = 0.35;
        let length = 10_000;
        let mut rng = ChaCha8Rng::seed_from_u64(1337);
        let packed = bernoulli_packed_simd(prob, length, &mut rng);
        let measured = ones_fraction(&packed, length);
        assert!(
            (measured - prob).abs() < 0.03,
            "Expected ~{prob}, got {measured}"
        );
    }

    /// Identical seeds give identical output from the batched generator.
    #[test]
    fn bernoulli_packed_simd_deterministic() {
        use rand::SeedableRng;
        use rand_chacha::ChaCha8Rng;
        let run = || {
            let mut rng = ChaCha8Rng::seed_from_u64(2026);
            bernoulli_packed_simd(0.5, 1024, &mut rng)
        };
        let a = run();
        let b = run();
        assert_eq!(a, b, "Same seed must produce identical output");
    }

    /// The fused encode+popcount equals materializing the stream and ANDing it.
    #[test]
    fn encode_and_popcount_matches_materialized() {
        use rand::SeedableRng;
        use rand_xoshiro::Xoshiro256PlusPlus;
        let prob = 0.41;
        for length in [63_usize, 64, 65, 1003, 1024] {
            let weights: Vec<u64> = (0..length.div_ceil(64))
                .map(|i| (i as u64).wrapping_mul(0x9E37_79B9_7F4A_7C15) ^ 0xA5A5_A5A5_5A5A_5A5A)
                .collect();

            let mut fused_rng = Xoshiro256PlusPlus::seed_from_u64(2026);
            let fused = encode_and_popcount(&weights, prob, length, &mut fused_rng);

            let mut reference_rng = Xoshiro256PlusPlus::seed_from_u64(2026);
            let encoded = bernoulli_packed_simd(prob, length, &mut reference_rng);
            let expected: u64 = encoded
                .iter()
                .zip(weights.iter())
                .map(|(&e, &w)| (e & w).count_ones() as u64)
                .sum();

            assert_eq!(fused, expected, "Mismatch at length={length}");
        }
    }
}