1use std::collections::HashMap;
22
/// Minimum number of observations a count vector must hold before it is
/// statistically tested. Below this total, `ks_reject_homogeneity` never
/// rejects and `run_cssr` places a history without testing it.
const MIN_OBSERVATIONS: u32 = 20;
26
/// A group of histories whose next-symbol counts have been pooled together.
#[derive(Debug, Clone)]
pub struct CausalState {
    /// Identifier of this state; `run_cssr` keeps it equal to the state's
    /// index in the state list.
    pub id: usize,
    /// Per-symbol next-symbol counts, summed over every absorbed history.
    pub pooled: Vec<u32>,
    /// The histories that have been absorbed into this state.
    pub histories: Vec<Vec<u8>>,
}
39
40impl CausalState {
41 fn new(id: usize, alphabet_size: usize) -> Self {
42 Self {
43 id,
44 pooled: vec![0u32; alphabet_size],
45 histories: Vec::new(),
46 }
47 }
48
49 fn total(&self) -> u32 {
50 self.pooled.iter().sum()
51 }
52
53 fn is_empty(&self) -> bool {
54 self.total() == 0 && self.histories.is_empty()
55 }
56
57 fn absorb(&mut self, history: Vec<u8>, counts: &[u32]) {
58 for (i, &c) in counts.iter().enumerate() {
59 self.pooled[i] += c;
60 }
61 self.histories.push(history);
62 }
63}
64
/// Output of `run_cssr`.
#[derive(Debug, Clone)]
pub struct CssrResult {
    /// The states that remain after merging and compaction.
    pub states: Vec<CausalState>,
    /// Maps each history to the id of the state it was assigned to.
    pub assignment: HashMap<Vec<u8>, usize>,
    /// The alphabet size that was passed to `run_cssr`.
    pub alphabet_size: usize,
    /// The maximum history length that was passed to `run_cssr`.
    pub max_depth: usize,
}
75
76#[must_use]
84pub fn ks_reject_homogeneity(counts_a: &[u32], counts_b: &[u32], alpha: f64) -> bool {
85 let n_a: u32 = counts_a.iter().sum();
86 let n_b: u32 = counts_b.iter().sum();
87
88 if n_a < MIN_OBSERVATIONS || n_b < MIN_OBSERVATIONS {
89 return false; }
91
92 let fa = f64::from(n_a);
93 let fb = f64::from(n_b);
94
95 let k = counts_a.len().max(counts_b.len());
97 let mut cum_a = 0u32;
98 let mut cum_b = 0u32;
99 let mut d_max: f64 = 0.0;
100
101 for i in 0..k {
102 cum_a += if i < counts_a.len() { counts_a[i] } else { 0 };
103 cum_b += if i < counts_b.len() { counts_b[i] } else { 0 };
104 let d = (f64::from(cum_a) / fa - f64::from(cum_b) / fb).abs();
105 if d > d_max {
106 d_max = d;
107 }
108 }
109
110 let c_alpha = (-0.5_f64 * alpha.ln()).sqrt();
113 let d_crit = c_alpha * ((fa + fb) / (fa * fb)).sqrt();
114
115 d_max > d_crit
116}
117
/// Tabulates, for every history of length `1..=max_depth` that occurs in
/// `symbols`, how often each symbol follows it.
///
/// The returned map is keyed by history (oldest symbol first); each value has
/// `alphabet_size` counters indexed by the following symbol. An occurrence
/// whose following symbol is `>= alphabet_size` is not counted.
#[must_use]
pub fn build_suffix_stats(
    symbols: &[u8],
    alphabet_size: usize,
    max_depth: usize,
) -> HashMap<Vec<u8>, Vec<u32>> {
    let mut table: HashMap<Vec<u8>, Vec<u32>> = HashMap::new();

    for depth in 1..=max_depth {
        // Each window is a history of `depth` symbols plus the symbol after it.
        for window in symbols.windows(depth + 1) {
            let (history, tail) = window.split_at(depth);
            let successor = usize::from(tail[0]);
            if successor < alphabet_size {
                let counters = table
                    .entry(history.to_vec())
                    .or_insert_with(|| vec![0u32; alphabet_size]);
                counters[successor] += 1;
            }
        }
    }

    table
}
150
151#[must_use]
165pub fn run_cssr(symbols: &[u8], alphabet_size: usize, max_depth: usize, alpha: f64) -> CssrResult {
166 let stats = build_suffix_stats(symbols, alphabet_size, max_depth);
167 let mut states: Vec<CausalState> = Vec::new();
168 let mut assignment: HashMap<Vec<u8>, usize> = HashMap::new();
169
170 for depth in 1..=max_depth {
172 let mut histories: Vec<Vec<u8>> =
174 stats.keys().filter(|h| h.len() == depth).cloned().collect();
175 histories.sort(); for history in histories {
178 let hist_counts = &stats[&history];
179 let hist_total: u32 = hist_counts.iter().sum();
180
181 let parent_key: Vec<u8> = if depth > 1 {
183 history[1..].to_vec()
184 } else {
185 vec![]
186 };
187 let parent_state = if depth > 1 {
188 assignment.get(&parent_key).copied()
189 } else {
190 None
191 };
192
193 let target_state: Option<usize> = if let Some(ps_id) = parent_state {
195 if hist_total < MIN_OBSERVATIONS {
197 Some(ps_id) } else {
199 let reject = ks_reject_homogeneity(&states[ps_id].pooled, hist_counts, alpha);
200 if reject {
201 find_compatible(&states, hist_counts, alpha)
203 } else {
204 Some(ps_id)
205 }
206 }
207 } else {
208 if hist_total < MIN_OBSERVATIONS {
210 states.first().map(|s| s.id) } else {
212 find_compatible(&states, hist_counts, alpha)
213 }
214 };
215
216 let sid = target_state.unwrap_or_else(|| {
217 let id = states.len();
218 states.push(CausalState::new(id, alphabet_size));
219 id
220 });
221
222 states[sid].absorb(history.clone(), hist_counts);
223 assignment.insert(history, sid);
224 }
225 }
226
227 merge_pass(&mut states, &mut assignment, alpha);
229
230 let remap = compact(&mut states);
232 for sid in assignment.values_mut() {
233 if let Some(&new_id) = remap.get(sid) {
234 *sid = new_id;
235 }
236 }
237
238 if states.is_empty() {
240 let mut s = CausalState::new(0, alphabet_size);
241 for (h, counts) in &stats {
242 s.absorb(h.clone(), counts);
243 assignment.insert(h.clone(), 0);
244 }
245 states.push(s);
246 }
247
248 CssrResult {
249 states,
250 assignment,
251 alphabet_size,
252 max_depth,
253 }
254}
255
256fn find_compatible(states: &[CausalState], hist_counts: &[u32], alpha: f64) -> Option<usize> {
261 states
262 .iter()
263 .filter(|s| !s.is_empty())
264 .find(|s| !ks_reject_homogeneity(&s.pooled, hist_counts, alpha))
265 .map(|s| s.id)
266}
267
268fn merge_pass(states: &mut Vec<CausalState>, assignment: &mut HashMap<Vec<u8>, usize>, alpha: f64) {
271 let mut changed = true;
272 while changed {
273 changed = false;
274 let n = states.len();
275 'outer: for i in 0..n {
276 for j in (i + 1)..n {
277 if states[i].is_empty() || states[j].is_empty() {
278 continue;
279 }
280 let a = states[i].pooled.clone();
281 let b = states[j].pooled.clone();
282 if !ks_reject_homogeneity(&a, &b, alpha) {
283 let j_hist = states[j].histories.clone();
285 let j_pooled = states[j].pooled.clone();
286 for (k, &c) in j_pooled.iter().enumerate() {
287 states[i].pooled[k] += c;
288 }
289 for h in j_hist {
290 assignment.insert(h.clone(), i);
291 states[i].histories.push(h);
292 }
293 states[j].pooled = vec![0; states[j].pooled.len()];
294 states[j].histories.clear();
295 changed = true;
296 break 'outer;
297 }
298 }
299 }
300 }
301}
302
303fn compact(states: &mut Vec<CausalState>) -> HashMap<usize, usize> {
305 let mut remap: HashMap<usize, usize> = HashMap::new();
306 let mut new_states: Vec<CausalState> = Vec::new();
307 for s in states.drain(..) {
308 if !s.is_empty() {
309 let new_id = new_states.len();
310 remap.insert(s.id, new_id);
311 let mut ns = s;
312 ns.id = new_id;
313 new_states.push(ns);
314 }
315 }
316 *states = new_states;
317 remap
318}
319
#[cfg(test)]
mod tests {
    use super::*;

    /// All mass on opposite symbols must be rejected.
    #[test]
    fn ks_rejects_clearly_different_distributions() {
        let only_first = [1000u32, 0];
        let only_second = [0u32, 1000];
        let rejected = ks_reject_homogeneity(&only_first, &only_second, 0.001);
        assert!(rejected);
    }

    /// Two nearly equal histograms must not be rejected.
    #[test]
    fn ks_accepts_identical_distributions() {
        let left = [667u32, 333];
        let right = [670u32, 330];
        let rejected = ks_reject_homogeneity(&left, &right, 0.001);
        assert!(!rejected);
    }

    /// With too few observations the test never rejects.
    #[test]
    fn ks_returns_false_for_small_samples() {
        let left = [5u32, 3];
        let right = [0u32, 8];
        let rejected = ks_reject_homogeneity(&left, &right, 0.001);
        assert!(!rejected);
    }

    /// Depth-1 transition counts of a strictly alternating sequence.
    #[test]
    fn build_suffix_stats_counts_correctly() {
        let seq = [0u8, 1, 0, 1, 0, 1, 0, 1];
        let stats = build_suffix_stats(&seq, 2, 1);

        let zero: &[u8] = &[0];
        let one: &[u8] = &[1];
        let after_0 = &stats[zero];
        let after_1 = &stats[one];

        assert_eq!(after_0[1], 4, "0 → 1 four times in 01010101");
        assert_eq!(after_1[0], 3, "1 → 0 three times in 01010101");
    }
}