use ndarray::{Array1, ArrayView1};
/// A fixed-length sequence of bits packed into 64-bit words.
///
/// Bit `i` lives in word `i / 64` at bit position `i % 64`. No method ever sets a bit at an
/// index `>= len`, so the unused high bits of the last word stay zero; the derived
/// `PartialEq` and [`BitVec::count_ones`] rely on that to work on whole words.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BitVec {
    words: Vec<u64>,
    len: usize,
}

impl BitVec {
    /// Creates a vector of `len` bits, all cleared.
    pub fn zeros(len: usize) -> Self {
        let words = vec![0u64; len.div_ceil(64)];
        Self { words, len }
    }

    /// Creates a vector of `len` bits, all set.
    pub fn ones(len: usize) -> Self {
        // Fill whole words at once instead of setting (and bounds-checking) bit by bit.
        let mut words = vec![u64::MAX; len.div_ceil(64)];
        let tail = len % 64;
        if tail != 0 {
            // Keep the padding bits of the final word zero so equality and
            // `count_ones` only ever see the first `len` bits.
            if let Some(last) = words.last_mut() {
                *last = (1u64 << tail) - 1;
            }
        }
        Self { words, len }
    }

    /// Returns the number of bits in the vector.
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns `true` when the vector holds zero bits.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Returns the value of bit `i`.
    ///
    /// # Panics
    ///
    /// Panics if `i >= self.len()`.
    #[inline]
    pub fn get(&self, i: usize) -> bool {
        assert!(
            i < self.len,
            "BitVec::get index {i} out of bounds {}",
            self.len
        );
        let (w, b) = (i / 64, i % 64);
        (self.words[w] >> b) & 1 == 1
    }

    /// Sets bit `i` to `v`.
    ///
    /// # Panics
    ///
    /// Panics if `i >= self.len()`.
    #[inline]
    pub fn set(&mut self, i: usize, v: bool) {
        assert!(
            i < self.len,
            "BitVec::set index {i} out of bounds {}",
            self.len
        );
        let (w, b) = (i / 64, i % 64);
        if v {
            self.words[w] |= 1u64 << b;
        } else {
            self.words[w] &= !(1u64 << b);
        }
    }

    /// Returns the number of set bits.
    pub fn count_ones(&self) -> usize {
        self.words.iter().map(|w| w.count_ones() as usize).sum()
    }

    /// Iterates over the indices of the set bits in ascending order.
    pub fn iter_ones(&self) -> impl Iterator<Item = usize> + '_ {
        // Walk the words and peel off set bits with `trailing_zeros`, so the cost scales
        // with the number of words plus the number of set bits rather than probing every
        // index through the asserting `get`.
        self.words.iter().enumerate().flat_map(|(w, &word)| {
            let base = w * 64;
            let mut rest = word;
            std::iter::from_fn(move || {
                if rest == 0 {
                    return None;
                }
                let bit = rest.trailing_zeros() as usize;
                // Clear the lowest set bit; `rest` is non-zero here so this cannot underflow.
                rest &= rest - 1;
                Some(base + bit)
            })
        })
    }

    /// Clears every bit; the length is unchanged.
    pub fn clear(&mut self) {
        self.words.fill(0);
    }
}
/// Activation mask plus per-atom weights for a single code over `k_atoms` atoms.
#[derive(Debug, Clone)]
pub struct SparseAtomCode {
    /// Bit `k` is set while atom `k` is active.
    pub active_mask: BitVec,
    /// One weight slot per atom; `assign` writes a slot and `deactivate` zeroes it.
    pub weights: Vec<f64>,
}

impl SparseAtomCode {
    /// Builds a code with `k_atoms` slots, no active atoms and all weights `0.0`.
    pub fn empty(k_atoms: usize) -> Self {
        let active_mask = BitVec::zeros(k_atoms);
        let weights = vec![0.0; k_atoms];
        Self {
            active_mask,
            weights,
        }
    }

    /// Number of atom slots, taken from the length of `weights`.
    pub fn k_atoms(&self) -> usize {
        self.weights.len()
    }

    /// Number of atoms whose mask bit is set.
    pub fn n_active(&self) -> usize {
        self.active_mask.count_ones()
    }

    /// Sum of the weights stored at the active atoms' slots.
    pub fn active_weight_sum(&self) -> f64 {
        let active = self.active_mask.iter_ones();
        active.map(|idx| self.weights[idx]).sum()
    }

    /// Marks atom `k` active and stores `w` as its weight.
    ///
    /// # Panics
    ///
    /// Panics if `k >= self.k_atoms()`.
    pub fn assign(&mut self, k: usize, w: f64) {
        assert!(k < self.k_atoms());
        let Self {
            active_mask,
            weights,
        } = self;
        active_mask.set(k, true);
        weights[k] = w;
    }

    /// Marks atom `k` inactive and resets its weight to `0.0`.
    ///
    /// # Panics
    ///
    /// Panics if `k >= self.k_atoms()`.
    pub fn deactivate(&mut self, k: usize) {
        assert!(k < self.k_atoms());
        let Self {
            active_mask,
            weights,
        } = self;
        active_mask.set(k, false);
        weights[k] = 0.0;
    }

    /// Borrows the raw weight slots as a 1-D array view; the mask is not applied.
    pub fn weights_view(&self) -> ArrayView1<'_, f64> {
        ArrayView1::from(self.weights.as_slice())
    }

    /// Returns an owned weight vector that is `0.0` everywhere except at the active atoms,
    /// which carry their stored weight.
    pub fn effective_weights(&self) -> Array1<f64> {
        let mut dense = Array1::<f64>::zeros(self.k_atoms());
        self.active_mask
            .iter_ones()
            .for_each(|idx| dense[idx] = self.weights[idx]);
        dense
    }
}
/// A collection of [`SparseAtomCode`] rows, one per observation, created with a common
/// atom count `k_atoms`.
#[derive(Debug, Clone)]
pub struct SparseAtomCodes {
    codes: Vec<SparseAtomCode>,
    k_atoms: usize,
}

impl SparseAtomCodes {
    /// Builds `n_obs` empty codes, each with `k_atoms` slots.
    pub fn empty(n_obs: usize, k_atoms: usize) -> Self {
        let codes: Vec<SparseAtomCode> =
            std::iter::repeat_with(|| SparseAtomCode::empty(k_atoms))
                .take(n_obs)
                .collect();
        Self { codes, k_atoms }
    }

    /// Number of stored codes (observations).
    pub fn n_obs(&self) -> usize {
        self.codes.len()
    }

    /// Atom count this collection was created with.
    pub fn k_atoms(&self) -> usize {
        self.k_atoms
    }

    /// Borrows the code of observation `n`.
    ///
    /// # Panics
    ///
    /// Panics if `n >= self.n_obs()`.
    pub fn row(&self, n: usize) -> &SparseAtomCode {
        &self.codes[n]
    }

    /// Mutably borrows the code of observation `n`.
    ///
    /// # Panics
    ///
    /// Panics if `n >= self.n_obs()`.
    pub fn row_mut(&mut self, n: usize) -> &mut SparseAtomCode {
        &mut self.codes[n]
    }

    /// Iterates over the codes in observation order.
    pub fn iter(&self) -> impl Iterator<Item = &SparseAtomCode> {
        self.codes.iter()
    }

    /// Iterates mutably over the codes in observation order.
    pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut SparseAtomCode> {
        self.codes.iter_mut()
    }

    /// Builds a dense `(n_obs, k_atoms)` matrix holding each active atom's weight and `0.0`
    /// everywhere else.
    pub fn weights_matrix(&self) -> ndarray::Array2<f64> {
        let shape = (self.n_obs(), self.k_atoms());
        let mut dense = ndarray::Array2::<f64>::zeros(shape);
        for (row_idx, code) in self.codes.iter().enumerate() {
            for atom in code.active_mask.iter_ones() {
                dense[[row_idx, atom]] = code.weights[atom];
            }
        }
        dense
    }

    /// Counts how often atoms `a` and `b` are active, separately and together, across all
    /// observations and derives the conditional activation rates and the lift.
    ///
    /// A conditional rate is `0.0` when its conditioning atom is never active; the lift is
    /// `0.0` when either atom is never active or there are no observations.
    ///
    /// # Panics
    ///
    /// Panics if `a` or `b` is not below `k_atoms`.
    pub fn coactivation(&self, a: usize, b: usize) -> CoactivationStats {
        assert!(
            a < self.k_atoms && b < self.k_atoms,
            "SparseAtomCodes::coactivation: atoms ({a}, {b}) out of range K={}",
            self.k_atoms
        );

        // Single pass accumulating (count of a, count of b, count of both).
        let (n_a, n_b, n_joint) = self.codes.iter().fold(
            (0usize, 0usize, 0usize),
            |(hits_a, hits_b, hits_both), code| {
                let on_a = code.active_mask.get(a);
                let on_b = code.active_mask.get(b);
                (
                    hits_a + usize::from(on_a),
                    hits_b + usize::from(on_b),
                    hits_both + usize::from(on_a && on_b),
                )
            },
        );
        let n_obs = self.n_obs();

        // joint / marginal, defined as 0 when the marginal count is 0.
        let ratio = |joint: usize, marg: usize| -> f64 {
            match marg {
                0 => 0.0,
                _ => joint as f64 / marg as f64,
            }
        };

        // Lift = (n_joint * n_obs) / (n_a * n_b), i.e. the joint rate over the product of
        // the marginal rates.
        let lift = if n_a != 0 && n_b != 0 && n_obs != 0 {
            (n_joint as f64 * n_obs as f64) / (n_a as f64 * n_b as f64)
        } else {
            0.0
        };

        CoactivationStats {
            n_obs,
            n_a,
            n_b,
            n_joint,
            p_a_given_b: ratio(n_joint, n_b),
            p_b_given_a: ratio(n_joint, n_a),
            lift,
        }
    }
}
/// Co-activation counts and derived rates for a pair of atoms `a` and `b`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct CoactivationStats {
    /// Number of observations scanned.
    pub n_obs: usize,
    /// Observations in which atom `a` was active.
    pub n_a: usize,
    /// Observations in which atom `b` was active.
    pub n_b: usize,
    /// Observations in which both atoms were active.
    pub n_joint: usize,
    /// Fraction of `b`-active observations in which `a` was also active.
    pub p_a_given_b: f64,
    /// Fraction of `a`-active observations in which `b` was also active.
    pub p_b_given_a: f64,
    /// Joint activation rate relative to the product of the marginal rates.
    pub lift: f64,
}

impl CoactivationStats {
    /// Returns the smaller of the two conditional activation rates.
    pub fn dependence(&self) -> f64 {
        f64::min(self.p_a_given_b, self.p_b_given_a)
    }

    /// Returns the absolute difference between the two conditional activation rates.
    pub fn absorption_asymmetry(&self) -> f64 {
        let gap = self.p_a_given_b - self.p_b_given_a;
        gap.abs()
    }
}
// Unit tests for the bit vector, the single sparse code, the code collection and the
// co-activation statistics.
#[cfg(test)]
mod tests {
use super::*;
// Exercises get/set on a 70-bit vector, including index 64, which lands in the second
// 64-bit word, and checks `count_ones` and the ascending order of `iter_ones`.
#[test]
fn bitvec_basic() {
let mut bv = BitVec::zeros(70);
assert_eq!(bv.len(), 70);
assert!(!bv.get(5));
bv.set(5, true);
bv.set(64, true);
assert!(bv.get(5));
assert!(bv.get(64));
assert_eq!(bv.count_ones(), 2);
let ones: Vec<usize> = bv.iter_ones().collect();
assert_eq!(ones, vec![5, 64]);
// Clearing a bit must drop it from the population count.
bv.set(5, false);
assert_eq!(bv.count_ones(), 1);
}
// Assigning activates an atom and stores its weight; deactivating clears the bit and
// zeroes the stored weight.
#[test]
fn sparse_code_assign() {
let mut c = SparseAtomCode::empty(8);
c.assign(2, 0.7);
c.assign(5, 0.3);
assert_eq!(c.n_active(), 2);
// 0.7 + 0.3 is compared with a tolerance because of floating-point rounding.
assert!((c.active_weight_sum() - 1.0).abs() < 1e-12);
c.deactivate(2);
assert_eq!(c.n_active(), 1);
assert_eq!(c.weights[2], 0.0);
}
// The dense matrix carries assigned weights at (observation, atom) and zero elsewhere.
#[test]
fn codes_matrix_roundtrip() {
let mut codes = SparseAtomCodes::empty(3, 4);
codes.row_mut(0).assign(1, 0.5);
codes.row_mut(2).assign(3, 0.9);
let m = codes.weights_matrix();
assert_eq!(m[[0, 1]], 0.5);
assert_eq!(m[[2, 3]], 0.9);
assert_eq!(m[[1, 0]], 0.0);
}
// Builds three activation patterns over 100 rows and checks that the statistics tell
// them apart: independent atoms, one atom nested inside another, and exact duplicates.
#[test]
fn coactivation_separates_independent_shattered_and_absorbed() {
let n = 100usize;
let mut codes = SparseAtomCodes::empty(n, 4);
for row in 0..n {
// Atom 0: even rows (50 of 100).
if row % 2 == 0 {
codes.row_mut(row).assign(0, 1.0);
}
// Atom 1: every fifth row (20 of 100).
if row % 5 == 0 {
codes.row_mut(row).assign(1, 1.0);
}
// Atom 2: rows divisible by 4 plus all odd rows (75 of 100).
if row % 4 == 0 || row % 2 == 1 {
codes.row_mut(row).assign(2, 1.0);
}
// Atom 3: rows divisible by 4 (25 of 100), a subset of atom 2's rows.
if row % 4 == 0 {
codes.row_mut(row).assign(3, 1.0);
}
}
// Atoms 0 and 1 overlap on multiples of 10: 10 * 100 / (50 * 20) gives lift 1.
let indep = codes.coactivation(0, 1);
assert_eq!(indep.n_joint, 10);
assert!((indep.p_a_given_b - 0.5).abs() < 1e-12);
assert!((indep.p_b_given_a - 0.2).abs() < 1e-12);
assert!((indep.lift - 1.0).abs() < 1e-12);
assert!(indep.dependence() < 0.25);
// Atom 3 is active only where atom 2 is: P(2|3) = 25/25, P(3|2) = 25/75.
let nested = codes.coactivation(2, 3);
assert!((nested.p_a_given_b - 1.0).abs() < 1e-12);
assert!(nested.p_b_given_a < 0.5);
assert!(nested.absorption_asymmetry() > 0.6);
// Two atoms that are always active on exactly the same rows (every third row).
let mut dup = SparseAtomCodes::empty(n, 2);
for row in (0..n).step_by(3) {
dup.row_mut(row).assign(0, 1.0);
dup.row_mut(row).assign(1, 1.0);
}
let shat = dup.coactivation(0, 1);
assert!((shat.dependence() - 1.0).abs() < 1e-12);
assert!(shat.absorption_asymmetry() < 1e-12);
// With no activations at all, the zero-marginal guards yield 0 rather than NaN.
let empty = SparseAtomCodes::empty(4, 2).coactivation(0, 1);
assert_eq!(empty.dependence(), 0.0);
assert_eq!(empty.lift, 0.0);
}
}