use ndarray::{Array1, ArrayView1};
/// A fixed-length sequence of bits packed into `u64` words.
///
/// Bit `i` is stored in word `i / 64` at bit position `i % 64`. Every method keeps
/// the unused high bits of the last word at zero, which is what lets `count_ones`,
/// `iter_ones` and the derived equality operate on whole words.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BitVec {
    /// Packed storage; holds `len.div_ceil(64)` words.
    words: Vec<u64>,
    /// Number of addressable bits.
    len: usize,
}

impl BitVec {
    /// Number of bits held by one storage word.
    const WORD_BITS: usize = u64::BITS as usize;

    /// Creates a vector of `len` bits, all cleared.
    pub fn zeros(len: usize) -> Self {
        Self {
            words: vec![0; len.div_ceil(Self::WORD_BITS)],
            len,
        }
    }

    /// Creates a vector of `len` bits, all set.
    pub fn ones(len: usize) -> Self {
        // Fill whole words at once instead of setting one bit at a time.
        let mut words = vec![u64::MAX; len.div_ceil(Self::WORD_BITS)];
        let tail = len % Self::WORD_BITS;
        if tail != 0 {
            // The last word is only partially used: keep the padding bits at zero.
            if let Some(last) = words.last_mut() {
                *last = (1u64 << tail) - 1;
            }
        }
        Self { words, len }
    }

    /// Returns the number of bits in the vector.
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns `true` when the vector holds no bits.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Returns the value of bit `i`.
    ///
    /// # Panics
    ///
    /// Panics if `i >= self.len()`.
    #[inline]
    pub fn get(&self, i: usize) -> bool {
        assert!(
            i < self.len,
            "BitVec::get index {i} out of bounds {}",
            self.len
        );
        let (w, b) = (i / Self::WORD_BITS, i % Self::WORD_BITS);
        (self.words[w] >> b) & 1 == 1
    }

    /// Sets bit `i` to `v`.
    ///
    /// # Panics
    ///
    /// Panics if `i >= self.len()`.
    #[inline]
    pub fn set(&mut self, i: usize, v: bool) {
        assert!(
            i < self.len,
            "BitVec::set index {i} out of bounds {}",
            self.len
        );
        let (w, b) = (i / Self::WORD_BITS, i % Self::WORD_BITS);
        let mask = 1u64 << b;
        if v {
            self.words[w] |= mask;
        } else {
            self.words[w] &= !mask;
        }
    }

    /// Returns how many bits are set.
    pub fn count_ones(&self) -> usize {
        self.words.iter().map(|w| w.count_ones() as usize).sum()
    }

    /// Iterates over the indices of the set bits in ascending order.
    pub fn iter_ones(&self) -> impl Iterator<Item = usize> + '_ {
        // Walk word by word and peel off the lowest set bit each step, so words
        // that are entirely zero cost a single comparison.
        self.words.iter().enumerate().flat_map(|(w, &word)| {
            let base = w * Self::WORD_BITS;
            let mut rest = word;
            std::iter::from_fn(move || {
                if rest == 0 {
                    return None;
                }
                let bit = rest.trailing_zeros() as usize;
                rest &= rest - 1; // clears the lowest set bit
                Some(base + bit)
            })
        })
    }

    /// Clears every bit; the length is unchanged.
    pub fn clear(&mut self) {
        self.words.fill(0);
    }
}
/// A weight vector over atoms paired with a bit mask marking which atoms are active.
#[derive(Debug, Clone)]
pub struct SparseAtomCode {
    /// Bit `k` is set when atom `k` is active.
    pub active_mask: BitVec,
    /// One stored weight per atom; `k_atoms` reports this vector's length.
    pub weights: Vec<f64>,
}

impl SparseAtomCode {
    /// Builds a code over `k_atoms` atoms with no active atom and all weights at `0.0`.
    pub fn empty(k_atoms: usize) -> Self {
        let active_mask = BitVec::zeros(k_atoms);
        let weights = vec![0.0; k_atoms];
        Self {
            active_mask,
            weights,
        }
    }

    /// Number of atoms, taken from the length of `weights`.
    pub fn k_atoms(&self) -> usize {
        self.weights.len()
    }

    /// Number of atoms currently flagged active in the mask.
    pub fn n_active(&self) -> usize {
        self.active_mask.count_ones()
    }

    /// Sum of the stored weights of the active atoms.
    pub fn active_weight_sum(&self) -> f64 {
        let active = self.active_mask.iter_ones();
        active.map(|idx| self.weights[idx]).sum()
    }

    /// Marks atom `k` active and stores `w` as its weight.
    ///
    /// # Panics
    ///
    /// Panics if `k >= self.k_atoms()`.
    pub fn assign(&mut self, k: usize, w: f64) {
        assert!(k < self.k_atoms());
        self.active_mask.set(k, true);
        self.weights[k] = w;
    }

    /// Marks atom `k` inactive and resets its stored weight to `0.0`.
    ///
    /// # Panics
    ///
    /// Panics if `k >= self.k_atoms()`.
    pub fn deactivate(&mut self, k: usize) {
        assert!(k < self.k_atoms());
        self.active_mask.set(k, false);
        self.weights[k] = 0.0;
    }

    /// Borrows the stored weights as a 1-D array view.
    ///
    /// The view exposes `weights` as stored; the mask is not applied.
    pub fn weights_view(&self) -> ArrayView1<'_, f64> {
        let stored: &[f64] = &self.weights;
        ArrayView1::from(stored)
    }

    /// Returns an owned array holding the stored weight at every active index
    /// and `0.0` everywhere else.
    pub fn effective_weights(&self) -> Array1<f64> {
        let mut dense = Array1::<f64>::zeros(self.k_atoms());
        for idx in self.active_mask.iter_ones() {
            dense[idx] = self.weights[idx];
        }
        dense
    }
}
/// A collection of `SparseAtomCode` rows, one per observation, together with
/// the atom count the collection was created with.
#[derive(Debug, Clone)]
pub struct SparseAtomCodes {
    /// One code per observation.
    codes: Vec<SparseAtomCode>,
    /// Atom count passed to `empty`; used as the column count of `weights_matrix`.
    k_atoms: usize,
}

impl SparseAtomCodes {
    /// Creates `n_obs` empty codes, each over `k_atoms` atoms.
    pub fn empty(n_obs: usize, k_atoms: usize) -> Self {
        Self {
            codes: vec![SparseAtomCode::empty(k_atoms); n_obs],
            k_atoms,
        }
    }

    /// Number of observations (rows).
    pub fn n_obs(&self) -> usize {
        self.codes.len()
    }

    /// Atom count recorded at construction.
    pub fn k_atoms(&self) -> usize {
        self.k_atoms
    }

    /// Borrows the code of observation `n`.
    ///
    /// # Panics
    ///
    /// Panics if `n >= self.n_obs()`.
    pub fn row(&self, n: usize) -> &SparseAtomCode {
        &self.codes[n]
    }

    /// Mutably borrows the code of observation `n`.
    ///
    /// # Panics
    ///
    /// Panics if `n >= self.n_obs()`.
    pub fn row_mut(&mut self, n: usize) -> &mut SparseAtomCode {
        &mut self.codes[n]
    }

    /// Iterates over the codes in observation order.
    pub fn iter(&self) -> impl Iterator<Item = &SparseAtomCode> {
        self.codes.iter()
    }

    /// Iterates mutably over the codes in observation order.
    pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut SparseAtomCode> {
        self.codes.iter_mut()
    }

    /// Builds an `n_obs x k_atoms` matrix whose entry `(n, k)` is the stored weight
    /// of atom `k` in row `n` when that atom is active, and `0.0` otherwise.
    pub fn weights_matrix(&self) -> ndarray::Array2<f64> {
        let shape = (self.n_obs(), self.k_atoms());
        let mut dense = ndarray::Array2::<f64>::zeros(shape);
        for (obs, code) in self.codes.iter().enumerate() {
            for atom in code.active_mask.iter_ones() {
                dense[[obs, atom]] = code.weights[atom];
            }
        }
        dense
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises get/set/count/iteration on a vector that spans two storage words.
    #[test]
    fn bitvec_basic() {
        let mut bits = BitVec::zeros(70);
        assert_eq!(bits.len(), 70);
        assert!(!bits.get(5));

        // One bit in the first word, one in the second.
        bits.set(5, true);
        bits.set(64, true);
        assert!(bits.get(5));
        assert!(bits.get(64));
        assert_eq!(bits.count_ones(), 2);

        let set_indices: Vec<usize> = bits.iter_ones().collect();
        assert_eq!(set_indices, vec![5, 64]);

        bits.set(5, false);
        assert_eq!(bits.count_ones(), 1);
    }

    /// Assigning activates atoms and stores weights; deactivating zeroes them.
    #[test]
    fn sparse_code_assign() {
        let mut code = SparseAtomCode::empty(8);
        code.assign(2, 0.7);
        code.assign(5, 0.3);
        assert_eq!(code.n_active(), 2);

        let total = code.active_weight_sum();
        assert!((total - 1.0).abs() < 1e-12);

        code.deactivate(2);
        assert_eq!(code.n_active(), 1);
        assert_eq!(code.weights[2], 0.0);
    }

    /// Weights assigned per row show up at the matching matrix positions.
    #[test]
    fn codes_matrix_roundtrip() {
        let mut all_codes = SparseAtomCodes::empty(3, 4);
        all_codes.row_mut(0).assign(1, 0.5);
        all_codes.row_mut(2).assign(3, 0.9);

        let dense = all_codes.weights_matrix();
        assert_eq!(dense[[0, 1]], 0.5);
        assert_eq!(dense[[2, 3]], 0.9);
        assert_eq!(dense[[1, 0]], 0.0);
    }
}