use std::collections::HashMap;
/// Minimum total count a histogram must reach before it is statistically
/// tested; `ks_reject_homogeneity` never rejects below this, and `run_cssr`
/// skips the test for histories with fewer observations.
const MIN_OBSERVATIONS: u32 = 20;
/// A group of histories together with their pooled next-symbol counts.
#[derive(Debug, Clone)]
pub struct CausalState {
    /// Identifier of the state; `run_cssr` and `compact` keep it equal to the
    /// state's position in the state vector.
    pub id: usize,
    /// Element-wise sum of the count vectors of every absorbed history.
    pub pooled: Vec<u32>,
    /// Every history that has been absorbed into this state, in absorption order.
    pub histories: Vec<Vec<u8>>,
}

impl CausalState {
    /// Creates a state with `alphabet_size` zeroed counters and no histories.
    fn new(id: usize, alphabet_size: usize) -> Self {
        let pooled = vec![0u32; alphabet_size];
        Self {
            id,
            pooled,
            histories: Vec::new(),
        }
    }

    /// Sum of all pooled counters.
    fn total(&self) -> u32 {
        self.pooled.iter().copied().sum()
    }

    /// `true` only when the state holds neither counts nor histories.
    fn is_empty(&self) -> bool {
        let has_counts = self.total() != 0;
        !has_counts && self.histories.is_empty()
    }

    /// Adds `counts` element-wise into the pooled counters and records `history`.
    ///
    /// Indexes `pooled` by position, so `counts` must not be longer than `pooled`.
    fn absorb(&mut self, history: Vec<u8>, counts: &[u32]) {
        for (slot, amount) in counts.iter().copied().enumerate() {
            self.pooled[slot] += amount;
        }
        self.histories.push(history);
    }
}
/// Value returned by `run_cssr`: the final states plus the history-to-state map.
#[derive(Debug, Clone)]
pub struct CssrResult {
    /// States remaining after `run_cssr` has merged and compacted them.
    pub states: Vec<CausalState>,
    /// Maps each history to the `id` of the state it was assigned to.
    pub assignment: HashMap<Vec<u8>, usize>,
    /// The `alphabet_size` argument `run_cssr` was called with.
    pub alphabet_size: usize,
    /// The `max_depth` argument `run_cssr` was called with.
    pub max_depth: usize,
}
/// Two-sample Kolmogorov–Smirnov style test on two count histograms.
///
/// Both histograms are turned into empirical CDFs over the bin index (a
/// shorter slice is treated as zero-padded), and the largest absolute gap
/// between them is compared against the asymptotic two-sided critical value
/// `c(alpha) * sqrt((n_a + n_b) / (n_a * n_b))` with
/// `c(alpha) = sqrt(-ln(alpha / 2) / 2)`.
///
/// * `counts_a`, `counts_b` – per-bin counts of the two samples.
/// * `alpha` – significance level; expected to lie in `(0, 1)`.
///
/// Returns `true` when the hypothesis that both samples share a distribution
/// is rejected. Returns `false` without testing when either sample has fewer
/// than `MIN_OBSERVATIONS` observations in total.
#[must_use]
pub fn ks_reject_homogeneity(counts_a: &[u32], counts_b: &[u32], alpha: f64) -> bool {
    let n_a: u32 = counts_a.iter().sum();
    let n_b: u32 = counts_b.iter().sum();
    if n_a < MIN_OBSERVATIONS || n_b < MIN_OBSERVATIONS {
        return false;
    }
    let fa = f64::from(n_a);
    let fb = f64::from(n_b);

    let bins = counts_a.len().max(counts_b.len());
    let mut cum_a = 0u32;
    let mut cum_b = 0u32;
    let mut d_max = 0.0_f64;
    for i in 0..bins {
        // Bins missing from the shorter histogram count as zero.
        cum_a += counts_a.get(i).copied().unwrap_or(0);
        cum_b += counts_b.get(i).copied().unwrap_or(0);
        let gap = (f64::from(cum_a) / fa - f64::from(cum_b) / fb).abs();
        d_max = d_max.max(gap);
    }

    // The statistic uses the absolute gap, i.e. the test is two-sided, so the
    // tail probability must be split: ln(alpha / 2), not ln(alpha).
    let c_alpha = (-0.5_f64 * (alpha / 2.0).ln()).sqrt();
    let d_crit = c_alpha * ((fa + fb) / (fa * fb)).sqrt();
    d_max > d_crit
}
/// Counts, for every history of length `1..=max_depth` occurring in
/// `symbols`, how often each symbol follows it.
///
/// The returned map is keyed by the history (oldest symbol first); each value
/// is a vector of `alphabet_size` counters indexed by the following symbol.
/// Positions whose following symbol is `>= alphabet_size` are skipped; the
/// symbols inside the history itself are not range-checked.
#[must_use]
pub fn build_suffix_stats(
    symbols: &[u8],
    alphabet_size: usize,
    max_depth: usize,
) -> HashMap<Vec<u8>, Vec<u32>> {
    let mut table: HashMap<Vec<u8>, Vec<u32>> = HashMap::new();
    for depth in 1..=max_depth {
        // Skipping `depth` positions guarantees a full-length history precedes `pos`.
        for (pos, &symbol) in symbols.iter().enumerate().skip(depth) {
            let successor = usize::from(symbol);
            if successor < alphabet_size {
                let context = symbols[pos - depth..pos].to_vec();
                let row = table
                    .entry(context)
                    .or_insert_with(|| vec![0u32; alphabet_size]);
                row[successor] += 1;
            }
        }
    }
    table
}
/// Groups the histories observed in `symbols` into states whose next-symbol
/// distributions the homogeneity test cannot tell apart.
///
/// Histories are processed by increasing length (up to `max_depth`) and, within
/// a length, in sorted order. Each history is placed as follows:
///
/// * with a known parent (the history minus its oldest symbol): it stays in
///   the parent's state unless it has at least `MIN_OBSERVATIONS` observations
///   and the test rejects the parent's pooled counts, in which case the first
///   compatible state is used;
/// * without a parent (length 1): a sparsely observed history joins the first
///   existing state, otherwise the first compatible state is used;
/// * whenever no state is chosen, a new state is created.
///
/// Afterwards indistinguishable states are merged, emptied states are dropped
/// and state ids are renumbered. If no state remains, one state holding every
/// entry of the statistics table is returned.
#[must_use]
pub fn run_cssr(symbols: &[u8], alphabet_size: usize, max_depth: usize, alpha: f64) -> CssrResult {
    let stats = build_suffix_stats(symbols, alphabet_size, max_depth);
    let mut states: Vec<CausalState> = Vec::new();
    let mut assignment: HashMap<Vec<u8>, usize> = HashMap::new();

    for depth in 1..=max_depth {
        // Sorted so that the outcome does not depend on hash-map iteration order.
        let mut level: Vec<&Vec<u8>> = stats.keys().filter(|h| h.len() == depth).collect();
        level.sort();

        for history in level {
            let hist_counts = &stats[history];
            let sparse = hist_counts.iter().sum::<u32>() < MIN_OBSERVATIONS;

            // The parent is the same history with its oldest symbol removed.
            let parent_state = if depth > 1 {
                assignment.get(&history[1..]).copied()
            } else {
                None
            };

            let target = match (parent_state, sparse) {
                (Some(parent), true) => Some(parent),
                (Some(parent), false) => {
                    if ks_reject_homogeneity(&states[parent].pooled, hist_counts, alpha) {
                        find_compatible(&states, hist_counts, alpha)
                    } else {
                        Some(parent)
                    }
                }
                (None, true) => states.first().map(|s| s.id),
                (None, false) => find_compatible(&states, hist_counts, alpha),
            };

            let sid = match target {
                Some(existing) => existing,
                None => {
                    let fresh = states.len();
                    states.push(CausalState::new(fresh, alphabet_size));
                    fresh
                }
            };
            states[sid].absorb(history.clone(), hist_counts);
            assignment.insert(history.clone(), sid);
        }
    }

    merge_pass(&mut states, &mut assignment, alpha);

    // Dropping emptied states renumbers the survivors; follow that in the map.
    let remap = compact(&mut states);
    for sid in assignment.values_mut() {
        if let Some(&renumbered) = remap.get(sid) {
            *sid = renumbered;
        }
    }

    if states.is_empty() {
        let mut lone = CausalState::new(0, alphabet_size);
        for (history, counts) in &stats {
            lone.absorb(history.clone(), counts);
            assignment.insert(history.clone(), 0);
        }
        states.push(lone);
    }

    CssrResult {
        states,
        assignment,
        alphabet_size,
        max_depth,
    }
}
/// Returns the `id` of the first non-empty state whose pooled counts are not
/// rejected as different from `hist_counts` at level `alpha`, or `None` when
/// every non-empty state is rejected.
fn find_compatible(states: &[CausalState], hist_counts: &[u32], alpha: f64) -> Option<usize> {
    for candidate in states {
        if candidate.is_empty() {
            continue;
        }
        if !ks_reject_homogeneity(&candidate.pooled, hist_counts, alpha) {
            return Some(candidate.id);
        }
    }
    None
}
/// Repeatedly merges pairs of non-empty states that the homogeneity test does
/// not reject, until a full scan finds no such pair.
///
/// On each round the first qualifying pair `(keep, gone)` with `keep < gone`
/// is merged: `gone`'s counts and histories move into `keep`, `assignment` is
/// updated for the moved histories, and `gone` is left empty (zeroed counters,
/// no histories) without being removed from `states`.
fn merge_pass(states: &mut Vec<CausalState>, assignment: &mut HashMap<Vec<u8>, usize>, alpha: f64) {
    loop {
        let count = states.len();
        let mut mergeable: Option<(usize, usize)> = None;
        'search: for keep in 0..count {
            for gone in (keep + 1)..count {
                let (left, right) = (&states[keep], &states[gone]);
                if left.is_empty() || right.is_empty() {
                    continue;
                }
                if !ks_reject_homogeneity(&left.pooled, &right.pooled, alpha) {
                    mergeable = Some((keep, gone));
                    break 'search;
                }
            }
        }

        let (keep, gone) = match mergeable {
            Some(pair) => pair,
            None => return,
        };

        // `keep < gone`, so splitting at `gone` yields two disjoint mutable views.
        let (head, tail) = states.split_at_mut(gone);
        let survivor = &mut head[keep];
        let absorbed = &mut tail[0];

        for (slot, &extra) in absorbed.pooled.iter().enumerate() {
            survivor.pooled[slot] += extra;
        }
        for history in absorbed.histories.drain(..) {
            assignment.insert(history.clone(), keep);
            survivor.histories.push(history);
        }
        absorbed.pooled.fill(0);
    }
}
/// Removes empty states from `states` and renumbers the survivors `0..n` in
/// their original order.
///
/// Returns a map from each surviving state's previous `id` to its new `id`;
/// removed states have no entry.
fn compact(states: &mut Vec<CausalState>) -> HashMap<usize, usize> {
    let survivors: Vec<CausalState> = states.drain(..).filter(|s| !s.is_empty()).collect();
    let mut remap = HashMap::new();
    for (new_id, mut state) in survivors.into_iter().enumerate() {
        remap.insert(state.id, new_id);
        state.id = new_id;
        states.push(state);
    }
    remap
}
#[cfg(test)]
mod tests {
    use super::*;

    // Two samples with all mass in opposite bins must be told apart.
    #[test]
    fn ks_rejects_clearly_different_distributions() {
        let left = [1000u32, 0];
        let right = [0u32, 1000];
        assert!(ks_reject_homogeneity(&left, &right, 0.001));
    }

    // Nearly equal proportions at n = 1000 must not be rejected.
    #[test]
    fn ks_accepts_identical_distributions() {
        let left = [667u32, 333];
        let right = [670u32, 330];
        assert!(!ks_reject_homogeneity(&left, &right, 0.001));
    }

    // Fewer than MIN_OBSERVATIONS observations: the test never rejects.
    #[test]
    fn ks_returns_false_for_small_samples() {
        let left = [5u32, 3];
        let right = [0u32, 8];
        assert!(!ks_reject_homogeneity(&left, &right, 0.001));
    }

    // Depth-1 transition counts of a strictly alternating sequence.
    #[test]
    fn build_suffix_stats_counts_correctly() {
        let seq = [0u8, 1, 0, 1, 0, 1, 0, 1];
        let stats = build_suffix_stats(&seq, 2, 1);
        let zero_history: &[u8] = &[0];
        let one_history: &[u8] = &[1];
        let after_0 = &stats[zero_history];
        let after_1 = &stats[one_history];
        assert_eq!(after_0[1], 4, "0 → 1 four times in 01010101");
        assert_eq!(after_1[0], 3, "1 → 0 three times in 01010101");
    }
}