use ndarray::Array1;
/// Fixed-length bit vector packed into 64-bit words.
///
/// Bit `i` is stored in word `i / 64` at bit position `i % 64`.
///
/// Invariant: every storage bit at position `>= len` is zero. All mutation goes
/// through [`BitVec::set`] (which rejects out-of-range indices), [`BitVec::ones`]
/// (which masks the tail word) or [`BitVec::clear`], so the invariant is
/// preserved; `count_ones`, `iter_ones` and the derived `PartialEq` rely on it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BitVec {
    words: Vec<u64>,
    len: usize,
}

impl BitVec {
    /// Creates a vector of `len` bits, all cleared.
    pub fn zeros(len: usize) -> Self {
        let words = vec![0u64; len.div_ceil(64)];
        Self { words, len }
    }

    /// Creates a vector of `len` bits, all set.
    ///
    /// Whole words are filled at once; only the final, partially used word is
    /// masked so that bits past `len` stay zero.
    pub fn ones(len: usize) -> Self {
        let mut words = vec![u64::MAX; len.div_ceil(64)];
        let tail_bits = len % 64;
        if tail_bits != 0 {
            // `tail_bits` is in 1..=63, so the shift cannot overflow.
            if let Some(last) = words.last_mut() {
                *last = (1u64 << tail_bits) - 1;
            }
        }
        Self { words, len }
    }

    /// Number of addressable bits.
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns `true` when the vector holds zero bits.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Returns the value of bit `i`.
    ///
    /// # Panics
    ///
    /// Panics if `i >= self.len()`.
    #[inline]
    pub fn get(&self, i: usize) -> bool {
        assert!(
            i < self.len,
            "BitVec::get index {i} out of bounds {}",
            self.len
        );
        let (w, b) = (i / 64, i % 64);
        (self.words[w] >> b) & 1 == 1
    }

    /// Sets bit `i` to `v`.
    ///
    /// # Panics
    ///
    /// Panics if `i >= self.len()`.
    #[inline]
    pub fn set(&mut self, i: usize, v: bool) {
        assert!(
            i < self.len,
            "BitVec::set index {i} out of bounds {}",
            self.len
        );
        let (w, b) = (i / 64, i % 64);
        if v {
            self.words[w] |= 1u64 << b;
        } else {
            self.words[w] &= !(1u64 << b);
        }
    }

    /// Number of set bits.
    pub fn count_ones(&self) -> usize {
        self.words.iter().map(|w| w.count_ones() as usize).sum()
    }

    /// Iterates over the indices of the set bits in ascending order.
    ///
    /// Works word by word: all-zero words are skipped in one step and each set
    /// bit is located with `trailing_zeros`, instead of probing every index.
    pub fn iter_ones(&self) -> impl Iterator<Item = usize> + '_ {
        self.words.iter().enumerate().flat_map(|(word_idx, &word)| {
            let base = word_idx * 64;
            let mut remaining = word;
            std::iter::from_fn(move || {
                if remaining == 0 {
                    return None;
                }
                let bit = remaining.trailing_zeros() as usize;
                // Drop the lowest set bit; `remaining` is non-zero here.
                remaining &= remaining - 1;
                Some(base + bit)
            })
        })
    }

    /// Clears every bit; the length is unchanged.
    pub fn clear(&mut self) {
        self.words.fill(0);
    }
}
/// Sparse code of one observation: an activity bit and a weight slot per atom.
///
/// `assign` and `deactivate` keep the two fields in step (an atom deactivated
/// through `deactivate` has its weight slot reset to `0.0`). Both fields are
/// public, so callers that edit them directly are responsible for keeping their
/// lengths equal.
#[derive(Debug, Clone)]
pub struct SparseAtomCode {
    /// Bit `k` is set while atom `k` is active.
    pub active_mask: BitVec,
    /// Weight slot for every atom; its length defines `k_atoms()`.
    pub weights: Vec<f64>,
}

impl SparseAtomCode {
    /// Builds a code over `k_atoms` atoms with nothing active and all weights `0.0`.
    pub fn empty(k_atoms: usize) -> Self {
        let active_mask = BitVec::zeros(k_atoms);
        let weights = vec![0.0; k_atoms];
        Self {
            active_mask,
            weights,
        }
    }

    /// Number of atoms this code ranges over (the length of `weights`).
    pub fn k_atoms(&self) -> usize {
        let slots = &self.weights;
        slots.len()
    }

    /// Number of atoms currently flagged active.
    pub fn n_active(&self) -> usize {
        let mask = &self.active_mask;
        mask.count_ones()
    }

    /// Sum of the weights stored for the active atoms only.
    pub fn active_weight_sum(&self) -> f64 {
        let active_atoms = self.active_mask.iter_ones();
        active_atoms.map(|atom| self.weights[atom]).sum()
    }

    /// Marks atom `k` active and stores `w` as its weight.
    ///
    /// # Panics
    ///
    /// Panics if `k` is not below `k_atoms()`.
    pub fn assign(&mut self, k: usize, w: f64) {
        assert!(k < self.k_atoms());
        let Self {
            active_mask,
            weights,
        } = self;
        active_mask.set(k, true);
        weights[k] = w;
    }

    /// Marks atom `k` inactive and resets its weight slot to `0.0`.
    ///
    /// # Panics
    ///
    /// Panics if `k` is not below `k_atoms()`.
    pub fn deactivate(&mut self, k: usize) {
        assert!(k < self.k_atoms());
        let Self {
            active_mask,
            weights,
        } = self;
        active_mask.set(k, false);
        weights[k] = 0.0;
    }

    /// Dense weight vector of length `k_atoms()`: the stored weight for every
    /// active atom and `0.0` for every inactive one.
    pub fn effective_weights(&self) -> Array1<f64> {
        let mut dense = Array1::<f64>::zeros(self.k_atoms());
        self.active_mask
            .iter_ones()
            .for_each(|atom| dense[atom] = self.weights[atom]);
        dense
    }
}
/// Sparse atom codes for a batch of observations: one [`SparseAtomCode`] row per
/// observation, all created over the same number of atoms.
#[derive(Debug, Clone)]
pub struct SparseAtomCodes {
    codes: Vec<SparseAtomCode>,
    k_atoms: usize,
}

impl SparseAtomCodes {
    /// Creates `n_obs` empty rows, each ranging over `k_atoms` atoms.
    pub fn empty(n_obs: usize, k_atoms: usize) -> Self {
        let codes = (0..n_obs).map(|_| SparseAtomCode::empty(k_atoms)).collect();
        Self { codes, k_atoms }
    }

    /// Number of observations (rows).
    pub fn n_obs(&self) -> usize {
        self.codes.len()
    }

    /// Number of atoms every row was created with.
    pub fn k_atoms(&self) -> usize {
        self.k_atoms
    }

    /// Shared access to row `n`.
    ///
    /// # Panics
    ///
    /// Panics if `n >= self.n_obs()`.
    pub fn row(&self, n: usize) -> &SparseAtomCode {
        &self.codes[n]
    }

    /// Mutable access to row `n`.
    ///
    /// # Panics
    ///
    /// Panics if `n >= self.n_obs()`.
    pub fn row_mut(&mut self, n: usize) -> &mut SparseAtomCode {
        &mut self.codes[n]
    }

    /// Iterates over the rows in observation order.
    pub fn iter(&self) -> impl Iterator<Item = &SparseAtomCode> {
        self.codes.iter()
    }

    /// Iterates mutably over the rows in observation order.
    pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut SparseAtomCode> {
        self.codes.iter_mut()
    }

    /// Dense `(n_obs, k_atoms)` matrix holding the stored weight of every active
    /// atom and `0.0` everywhere else.
    pub fn weights_matrix(&self) -> ndarray::Array2<f64> {
        let mut out = ndarray::Array2::<f64>::zeros((self.n_obs(), self.k_atoms()));
        for (n_idx, code) in self.codes.iter().enumerate() {
            for kk in code.active_mask.iter_ones() {
                out[[n_idx, kk]] = code.weights[kk];
            }
        }
        out
    }

    /// Support-overlap statistics for atoms `a` and `b` over all rows.
    ///
    /// Counts the rows where each atom is active and where both are, then
    /// derives the conditional activation frequencies, the lift
    /// `n_joint * n_obs / (n_a * n_b)` and the weight correlation returned by
    /// [`SparseAtomCodes::weight_codependence`]. Every ratio whose denominator
    /// would be zero is reported as `0.0`.
    ///
    /// # Panics
    ///
    /// Panics if `a` or `b` is not below `k_atoms()`.
    pub fn coactivation(&self, a: usize, b: usize) -> CoactivationStats {
        assert!(
            a < self.k_atoms && b < self.k_atoms,
            "SparseAtomCodes::coactivation: atoms ({a}, {b}) out of range K={}",
            self.k_atoms
        );
        let n_obs = self.n_obs();
        let mut n_a = 0usize;
        let mut n_b = 0usize;
        let mut n_joint = 0usize;
        for code in &self.codes {
            let on_a = code.active_mask.get(a);
            let on_b = code.active_mask.get(b);
            n_a += usize::from(on_a);
            n_b += usize::from(on_b);
            n_joint += usize::from(on_a && on_b);
        }
        // Conditional frequency with the empty-support case pinned to 0.0.
        let cond = |joint: usize, marg: usize| {
            if marg == 0 {
                0.0
            } else {
                joint as f64 / marg as f64
            }
        };
        let lift = if n_a == 0 || n_b == 0 || n_obs == 0 {
            0.0
        } else {
            (n_joint as f64 * n_obs as f64) / (n_a as f64 * n_b as f64)
        };
        CoactivationStats {
            n_obs,
            n_a,
            n_b,
            n_joint,
            p_a_given_b: cond(n_joint, n_b),
            p_b_given_a: cond(n_joint, n_a),
            lift,
            weight_correlation: self.weight_codependence(a, b),
        }
    }

    /// Pearson correlation of the weights of atoms `a` and `b`, taken over the
    /// rows where both atoms are active, clamped to `[-1.0, 1.0]`.
    ///
    /// Returns `0.0` when the correlation is undefined: fewer than two jointly
    /// active rows, or a weight column with no variation on the joint support.
    ///
    /// # Panics
    ///
    /// Panics if `a` or `b` is not below `k_atoms()`.
    pub fn weight_codependence(&self, a: usize, b: usize) -> f64 {
        /// `true` when every entry compares equal to the first one.
        fn is_constant(xs: &[f64]) -> bool {
            xs.iter().all(|&x| x == xs[0])
        }

        assert!(
            a < self.k_atoms && b < self.k_atoms,
            "SparseAtomCodes::weight_codependence: atoms ({a}, {b}) out of range K={}",
            self.k_atoms
        );
        let (wa, wb): (Vec<f64>, Vec<f64>) = self
            .codes
            .iter()
            .filter(|code| code.active_mask.get(a) && code.active_mask.get(b))
            .map(|code| (code.weights[a], code.weights[b]))
            .unzip();
        let m = wa.len();
        if m < 2 {
            return 0.0;
        }
        // A constant column has exactly zero variance, but the mean computed
        // below is rounded (e.g. 49.0 * (1.0 / 49.0) == 0.9999999999999999), so
        // every deviation would become the same tiny non-zero residue and the
        // ratio would come out as a spurious 1.0. Detect the case exactly
        // instead of relying on the rounded variance being zero.
        if is_constant(&wa) || is_constant(&wb) {
            return 0.0;
        }
        let inv = 1.0 / m as f64;
        let mean_a: f64 = wa.iter().sum::<f64>() * inv;
        let mean_b: f64 = wb.iter().sum::<f64>() * inv;
        let mut cov = 0.0_f64;
        let mut var_a = 0.0_f64;
        let mut var_b = 0.0_f64;
        for (&x, &y) in wa.iter().zip(wb.iter()) {
            let da = x - mean_a;
            let db = y - mean_b;
            cov += da * db;
            var_a += da * da;
            var_b += db * db;
        }
        // Written as a negated conjunction so NaN variances also take the
        // fallback branch.
        if !(var_a > 0.0 && var_b > 0.0) {
            return 0.0;
        }
        let rho = cov / (var_a.sqrt() * var_b.sqrt());
        rho.clamp(-1.0, 1.0)
    }
}
/// Pairwise support and weight statistics for two atoms `a` and `b`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct CoactivationStats {
    /// Number of observations the statistics were taken over.
    pub n_obs: usize,
    /// Observations in which atom `a` is active.
    pub n_a: usize,
    /// Observations in which atom `b` is active.
    pub n_b: usize,
    /// Observations in which both atoms are active.
    pub n_joint: usize,
    /// Fraction of `b`-active observations in which `a` is active too.
    pub p_a_given_b: f64,
    /// Fraction of `a`-active observations in which `b` is active too.
    pub p_b_given_a: f64,
    /// Joint activation count relative to the product of the marginal counts,
    /// scaled by `n_obs`.
    pub lift: f64,
    /// Correlation of the two atoms' weights on the jointly active observations.
    pub weight_correlation: f64,
}

impl CoactivationStats {
    /// The smaller of the two conditional activation frequencies.
    pub fn dependence(&self) -> f64 {
        f64::min(self.p_a_given_b, self.p_b_given_a)
    }

    /// Absolute gap between the two conditional activation frequencies.
    pub fn absorption_asymmetry(&self) -> f64 {
        let gap = self.p_a_given_b - self.p_b_given_a;
        gap.abs()
    }

    /// `dependence()` scaled by the magnitude of the weight correlation.
    pub fn fusion_evidence(&self) -> f64 {
        let support_overlap = self.dependence();
        let correlation_strength = self.weight_correlation.abs();
        support_overlap * correlation_strength
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Get/set/count/iterate on a bit vector that spans two storage words.
    #[test]
    fn bitvec_basic() {
        let mut bv = BitVec::zeros(70);
        assert_eq!(bv.len(), 70);
        assert!(!bv.get(5));
        // One bit in the first word, one in the second.
        for idx in [5usize, 64] {
            bv.set(idx, true);
        }
        assert!(bv.get(5));
        assert!(bv.get(64));
        assert_eq!(bv.count_ones(), 2);
        let ones = bv.iter_ones().collect::<Vec<usize>>();
        assert_eq!(ones, vec![5, 64]);
        bv.set(5, false);
        assert_eq!(bv.count_ones(), 1);
    }

    /// `assign` activates and stores a weight; `deactivate` undoes both.
    #[test]
    fn sparse_code_assign() {
        let mut c = SparseAtomCode::empty(8);
        for (atom, weight) in [(2usize, 0.7), (5, 0.3)] {
            c.assign(atom, weight);
        }
        assert_eq!(c.n_active(), 2);
        assert!((c.active_weight_sum() - 1.0).abs() < 1e-12);
        c.deactivate(2);
        assert_eq!(c.n_active(), 1);
        assert_eq!(c.weights[2], 0.0);
    }

    /// Assigned weights land in the matching cells of the dense matrix.
    #[test]
    fn codes_matrix_roundtrip() {
        let mut codes = SparseAtomCodes::empty(3, 4);
        {
            let first = codes.row_mut(0);
            first.assign(1, 0.5);
        }
        {
            let last = codes.row_mut(2);
            last.assign(3, 0.9);
        }
        let m = codes.weights_matrix();
        assert_eq!(m[[0, 1]], 0.5);
        assert_eq!(m[[2, 3]], 0.9);
        assert_eq!(m[[1, 0]], 0.0);
    }

    /// Support statistics on planted independent, nested and duplicated pairs,
    /// plus the all-inactive edge case.
    #[test]
    fn coactivation_separates_independent_shattered_and_absorbed() {
        let n = 100usize;
        let mut codes = SparseAtomCodes::empty(n, 4);
        for row in 0..n {
            // Activation schedule per atom for this row.
            let schedule = [
                (0usize, row % 2 == 0),
                (1, row % 5 == 0),
                (2, row % 4 == 0 || row % 2 == 1),
                (3, row % 4 == 0),
            ];
            let code = codes.row_mut(row);
            for (atom, active) in schedule {
                if active {
                    code.assign(atom, 1.0);
                }
            }
        }

        // Atoms 0 and 1: supports overlap exactly as often as chance predicts.
        let indep = codes.coactivation(0, 1);
        assert_eq!(indep.n_joint, 10);
        assert!((indep.p_a_given_b - 0.5).abs() < 1e-12);
        assert!((indep.p_b_given_a - 0.2).abs() < 1e-12);
        assert!((indep.lift - 1.0).abs() < 1e-12);
        assert!(indep.dependence() < 0.25);

        // Atom 3's support is contained in atom 2's.
        let nested = codes.coactivation(2, 3);
        assert!((nested.p_a_given_b - 1.0).abs() < 1e-12);
        assert!(nested.p_b_given_a < 0.5);
        assert!(nested.absorption_asymmetry() > 0.6);

        // Two atoms with identical support.
        let mut dup = SparseAtomCodes::empty(n, 2);
        for row in (0..n).step_by(3) {
            let code = dup.row_mut(row);
            code.assign(0, 1.0);
            code.assign(1, 1.0);
        }
        let shat = dup.coactivation(0, 1);
        assert!((shat.dependence() - 1.0).abs() < 1e-12);
        assert!(shat.absorption_asymmetry() < 1e-12);

        // Nothing active anywhere: every ratio falls back to zero.
        let empty = SparseAtomCodes::empty(4, 2).coactivation(0, 1);
        assert_eq!(empty.dependence(), 0.0);
        assert_eq!(empty.lift, 0.0);
    }

    /// Fusion evidence fires on a pair whose weights sum to one on a shared
    /// support and stays low for a pair whose weights vary independently.
    #[test]
    fn fusion_criterion_distinguishes_shattered_from_independent_coactive() {
        use std::f64::consts::PI;

        let n = 120usize;

        // Weights `t` and `1 - t` on every row.
        let mut shattered = SparseAtomCodes::empty(n, 2);
        for row in 0..n {
            let t = (row as f64 + 0.5) / n as f64;
            let code = shattered.row_mut(row);
            code.assign(0, t);
            code.assign(1, 1.0 - t);
        }
        let shat = shattered.coactivation(0, 1);
        assert!(
            (shat.dependence() - 1.0).abs() < 1e-12,
            "shattered pair shares support: dependence={}",
            shat.dependence()
        );
        assert!(
            shat.weight_correlation < -0.99,
            "shattered family's partition-of-unity weights are anti-correlated on \
             the joint support: weight_correlation={}",
            shat.weight_correlation
        );
        assert!(
            shat.fusion_evidence() > 0.95,
            "fusion evidence must FIRE on a planted shatter: {}",
            shat.fusion_evidence()
        );

        // Same support, but weights driven by sinusoids of different periods.
        let mut independent = SparseAtomCodes::empty(n, 2);
        for row in 0..n {
            let x = row as f64;
            let wa = 0.5 + 0.4 * (2.0 * PI * x / 7.0).sin();
            let wb = 0.5 + 0.4 * (2.0 * PI * x / 11.0 + 1.3).cos();
            let code = independent.row_mut(row);
            code.assign(0, wa);
            code.assign(1, wb);
        }
        let indep = independent.coactivation(0, 1);
        assert!(
            (indep.dependence() - 1.0).abs() < 1e-12,
            "independent pair was constructed with identical support: dependence={}",
            indep.dependence()
        );
        assert!(
            indep.weight_correlation.abs() < 0.3,
            "independent co-active atoms have ~uncorrelated weights: \
             weight_correlation={}",
            indep.weight_correlation
        );
        assert!(
            shat.fusion_evidence() > 3.0 * indep.fusion_evidence().max(1e-6),
            "fusion evidence must rank the shattered pair far above the independent \
             pair: shattered={}, independent={}",
            shat.fusion_evidence(),
            indep.fusion_evidence()
        );

        // A single jointly active row is below the two-row minimum.
        let mut tiny = SparseAtomCodes::empty(3, 2);
        {
            let only = tiny.row_mut(0);
            only.assign(0, 1.0);
            only.assign(1, 1.0);
        }
        assert_eq!(
            tiny.weight_codependence(0, 1),
            0.0,
            "a single jointly-active row carries no amplitude correlation"
        );
    }
}