// epsilon_engine/cssr.rs
1//! Pillar: III. PACR field: Γ.
2//!
3//! Causal State Splitting Reconstruction (CSSR) algorithm.
4//!
5//! Ref: Shalizi & Crutchfield (2001), "Computational Mechanics: Pattern and
6//! Prediction, Structure and Simplicity", J. Stat. Phys. 104(3-4):817-879.
7//!
8//! # Algorithm outline
9//!
10//! 1. Build suffix statistics: for every observed history `h` of length
11//!    `1..=max_depth`, count how often each symbol follows `h`.
12//! 2. Process histories depth-first (L=1 first).
13//!    For each history `h` of length L:
14//!    a. Find parent history `h[1..]` and its assigned causal state S.
15//!    b. KS-test: is `h`'s conditional distribution homogeneous with S?
16//!       • Yes → assign `h` to S (add `h`'s counts to S's pool).
17//!       • No  → search other states for a homogeneous match; else new state.
18//! 3. Merge pass: collapse pairs of states that are now indistinguishable.
19//! 4. Compact (remove empty states, re-index).
20
21use std::collections::HashMap;
22
/// Minimum per-history observations before eligibility for KS testing.
/// Histories with fewer observations are skipped (assigned to parent state).
/// Also enforced independently inside `ks_reject_homogeneity`, which returns
/// `false` (assume homogeneous) when either sample is below this threshold.
const MIN_OBSERVATIONS: u32 = 20;
26
27// ── Data structures ───────────────────────────────────────────────────────────
28
/// A causal state: an equivalence class of histories sharing the same
/// conditional future distribution.
#[derive(Debug, Clone)]
pub struct CausalState {
    pub id: usize,
    /// Pooled next-symbol counts across all histories assigned to this state.
    pub pooled: Vec<u32>,
    /// Histories (as symbol sequences) assigned to this state.
    pub histories: Vec<Vec<u8>>,
}

impl CausalState {
    /// A fresh, empty state with a zeroed count vector of width `alphabet_size`.
    fn new(id: usize, alphabet_size: usize) -> Self {
        let pooled = vec![0u32; alphabet_size];
        let histories = Vec::new();
        CausalState { id, pooled, histories }
    }

    /// Total number of pooled observations across the whole alphabet.
    fn total(&self) -> u32 {
        self.pooled.iter().copied().sum()
    }

    /// True when the state carries neither counts nor member histories.
    fn is_empty(&self) -> bool {
        self.histories.is_empty() && self.total() == 0
    }

    /// Fold `counts` into the pooled totals and record `history` as a member.
    /// Indexing stays strict: `counts` wider than the alphabet panics, exactly
    /// like the direct-index form.
    fn absorb(&mut self, history: Vec<u8>, counts: &[u32]) {
        counts
            .iter()
            .enumerate()
            .for_each(|(sym, &c)| self.pooled[sym] += c);
        self.histories.push(history);
    }
}
64
/// Output of a single CSSR run.
#[derive(Debug, Clone)]
pub struct CssrResult {
    /// Final causal states (non-empty, re-indexed 0..k).
    pub states: Vec<CausalState>,
    /// Maps every observed history to its state index.
    pub assignment: HashMap<Vec<u8>, usize>,
    /// Alphabet size `|A|` this run was performed with.
    pub alphabet_size: usize,
    /// Maximum history length `L` used when building suffix statistics.
    pub max_depth: usize,
}
75
76// ── Two-sample Kolmogorov–Smirnov test ───────────────────────────────────────
77
78/// Two-sample KS test for discrete distributions.
79///
80/// Returns `true` (reject homogeneity) when the empirical CDFs differ by more
81/// than the `alpha`-level critical value.  Returns `false` when either sample
82/// is too small (`< MIN_OBSERVATIONS`) — conservative (assume homogeneous).
83#[must_use]
84pub fn ks_reject_homogeneity(counts_a: &[u32], counts_b: &[u32], alpha: f64) -> bool {
85    let n_a: u32 = counts_a.iter().sum();
86    let n_b: u32 = counts_b.iter().sum();
87
88    if n_a < MIN_OBSERVATIONS || n_b < MIN_OBSERVATIONS {
89        return false; // insufficient data — do not split
90    }
91
92    let fa = f64::from(n_a);
93    let fb = f64::from(n_b);
94
95    // Maximum absolute CDF difference over the discrete alphabet.
96    let k = counts_a.len().max(counts_b.len());
97    let mut cum_a = 0u32;
98    let mut cum_b = 0u32;
99    let mut d_max: f64 = 0.0;
100
101    for i in 0..k {
102        cum_a += if i < counts_a.len() { counts_a[i] } else { 0 };
103        cum_b += if i < counts_b.len() { counts_b[i] } else { 0 };
104        let d = (f64::from(cum_a) / fa - f64::from(cum_b) / fb).abs();
105        if d > d_max {
106            d_max = d;
107        }
108    }
109
110    // Asymptotic critical value: D_crit = c_α × sqrt((n_A + n_B) / (n_A × n_B)).
111    // c_α = sqrt(-0.5 × ln α).
112    let c_alpha = (-0.5_f64 * alpha.ln()).sqrt();
113    let d_crit = c_alpha * ((fa + fb) / (fa * fb)).sqrt();
114
115    d_max > d_crit
116}
117
118// ── Suffix statistics ─────────────────────────────────────────────────────────
119
/// Build per-suffix next-symbol count vectors for all depths `1..=max_depth`.
///
/// Memory: O(N × `max_depth`) worst-case, but histories with identical byte
/// sequences share one entry.  For typical processes this is
/// O(`|A|^max_depth × |A|`) entries, capped by N.
#[must_use]
pub fn build_suffix_stats(
    symbols: &[u8],
    alphabet_size: usize,
    max_depth: usize,
) -> HashMap<Vec<u8>, Vec<u32>> {
    let mut stats: HashMap<Vec<u8>, Vec<u32>> = HashMap::new();

    for depth in 1..=max_depth {
        // Every window of `depth + 1` symbols is one history plus its successor.
        for window in symbols.windows(depth + 1) {
            let (history, tail) = window.split_at(depth);
            let next = tail[0] as usize;
            // Out-of-alphabet successors are dropped entirely: the history is
            // not even inserted, matching a skip before entry creation.
            if next < alphabet_size {
                stats
                    .entry(history.to_vec())
                    .or_insert_with(|| vec![0u32; alphabet_size])[next] += 1;
            }
        }
    }

    stats
}
150
151// ── Core CSSR ─────────────────────────────────────────────────────────────────
152
/// Run CSSR on a discrete symbol sequence.
///
/// # Arguments
/// * `symbols`       — observed symbol sequence (values `0..alphabet_size`)
/// * `alphabet_size` — `|A|`
/// * `max_depth`     — maximum history length `L`
/// * `alpha`         — KS significance level (e.g. `0.001`)
///
/// # Returns
///
/// [`CssrResult`] with the inferred causal states and history → state map.
#[must_use]
pub fn run_cssr(symbols: &[u8], alphabet_size: usize, max_depth: usize, alpha: f64) -> CssrResult {
    let stats = build_suffix_stats(symbols, alphabet_size, max_depth);
    let mut states: Vec<CausalState> = Vec::new();
    let mut assignment: HashMap<Vec<u8>, usize> = HashMap::new();

    // Process histories depth-by-depth (L=1 first), so a depth-L history's
    // parent (its length-(L-1) suffix) has already been assigned a state by
    // the time we reach it.
    for depth in 1..=max_depth {
        // Collect all observed histories of this depth.
        let mut histories: Vec<Vec<u8>> =
            stats.keys().filter(|h| h.len() == depth).cloned().collect();
        histories.sort(); // deterministic order (HashMap iteration is unordered)

        for history in histories {
            let hist_counts = &stats[&history];
            let hist_total: u32 = hist_counts.iter().sum();

            // Parent history: drop the oldest (first) symbol.
            let parent_key: Vec<u8> = if depth > 1 {
                history[1..].to_vec()
            } else {
                vec![]
            };
            // Depth-1 histories have no parent: the empty history is never
            // stored in `assignment`, so the lookup is skipped outright.
            let parent_state = if depth > 1 {
                assignment.get(&parent_key).copied()
            } else {
                None
            };

            // --- Assign this history to a causal state ---
            let target_state: Option<usize> = if let Some(ps_id) = parent_state {
                // Is this history homogeneous with its parent's state?
                if hist_total < MIN_OBSERVATIONS {
                    Some(ps_id) // too few obs — keep with parent
                } else {
                    let reject = ks_reject_homogeneity(&states[ps_id].pooled, hist_counts, alpha);
                    if reject {
                        // Heterogeneous: find another compatible state.
                        // (The parent state cannot be returned here — the same
                        // deterministic KS test just rejected it.)
                        find_compatible(&states, hist_counts, alpha)
                    } else {
                        Some(ps_id)
                    }
                }
            } else {
                // Depth-1 with no parent: find any compatible state.
                if hist_total < MIN_OBSERVATIONS {
                    states.first().map(|s| s.id) // assign to first available
                } else {
                    find_compatible(&states, hist_counts, alpha)
                }
            };

            // No compatible state anywhere → open a fresh state; its id is its
            // index in `states`, an invariant `find_compatible` relies on.
            let sid = target_state.unwrap_or_else(|| {
                let id = states.len();
                states.push(CausalState::new(id, alphabet_size));
                id
            });

            states[sid].absorb(history.clone(), hist_counts);
            assignment.insert(history, sid);
        }
    }

    // Merge pass: collapse any two states that have become indistinguishable.
    merge_pass(&mut states, &mut assignment, alpha);

    // Compact: remove empty states and re-index.  Only states emptied by the
    // merge pass are dropped, and their histories were re-pointed during the
    // merge, so every live assignment value has an entry in `remap`.
    let remap = compact(&mut states);
    for sid in assignment.values_mut() {
        if let Some(&new_id) = remap.get(sid) {
            *sid = new_id;
        }
    }

    // Ensure at least one state (e.g. the input was too short to produce any
    // history, or every observation was out-of-alphabet): pool everything
    // into a single state 0.
    if states.is_empty() {
        let mut s = CausalState::new(0, alphabet_size);
        for (h, counts) in &stats {
            s.absorb(h.clone(), counts);
            assignment.insert(h.clone(), 0);
        }
        states.push(s);
    }

    CssrResult {
        states,
        assignment,
        alphabet_size,
        max_depth,
    }
}
255
256// ── Helpers ───────────────────────────────────────────────────────────────────
257
258/// Find the first existing state whose pooled distribution is homogeneous with
259/// `hist_counts` at significance `alpha`.  Returns `None` if no match found.
260fn find_compatible(states: &[CausalState], hist_counts: &[u32], alpha: f64) -> Option<usize> {
261    states
262        .iter()
263        .filter(|s| !s.is_empty())
264        .find(|s| !ks_reject_homogeneity(&s.pooled, hist_counts, alpha))
265        .map(|s| s.id)
266}
267
268/// Merge pass: repeatedly scan for pairs of states whose pooled distributions
269/// are homogeneous; merge the larger-index into the smaller-index.
270fn merge_pass(states: &mut Vec<CausalState>, assignment: &mut HashMap<Vec<u8>, usize>, alpha: f64) {
271    let mut changed = true;
272    while changed {
273        changed = false;
274        let n = states.len();
275        'outer: for i in 0..n {
276            for j in (i + 1)..n {
277                if states[i].is_empty() || states[j].is_empty() {
278                    continue;
279                }
280                let a = states[i].pooled.clone();
281                let b = states[j].pooled.clone();
282                if !ks_reject_homogeneity(&a, &b, alpha) {
283                    // Merge j → i.
284                    let j_hist = states[j].histories.clone();
285                    let j_pooled = states[j].pooled.clone();
286                    for (k, &c) in j_pooled.iter().enumerate() {
287                        states[i].pooled[k] += c;
288                    }
289                    for h in j_hist {
290                        assignment.insert(h.clone(), i);
291                        states[i].histories.push(h);
292                    }
293                    states[j].pooled = vec![0; states[j].pooled.len()];
294                    states[j].histories.clear();
295                    changed = true;
296                    break 'outer;
297                }
298            }
299        }
300    }
301}
302
303/// Remove empty states and return an old→new index map.
304fn compact(states: &mut Vec<CausalState>) -> HashMap<usize, usize> {
305    let mut remap: HashMap<usize, usize> = HashMap::new();
306    let mut new_states: Vec<CausalState> = Vec::new();
307    for s in states.drain(..) {
308        if !s.is_empty() {
309            let new_id = new_states.len();
310            remap.insert(s.id, new_id);
311            let mut ns = s;
312            ns.id = new_id;
313            new_states.push(ns);
314        }
315    }
316    *states = new_states;
317    remap
318}
319
320// ── Tests ─────────────────────────────────────────────────────────────────────
321
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn ks_rejects_clearly_different_distributions() {
        // [1000, 0] vs [0, 1000] — maximally different (CDF gap = 1.0)
        let a = vec![1000u32, 0];
        let b = vec![0u32, 1000];
        assert!(ks_reject_homogeneity(&a, &b, 0.001));
    }

    #[test]
    fn ks_accepts_identical_distributions() {
        // Near-identical proportions with large n — must not reject.
        let a = vec![667u32, 333];
        let b = vec![670u32, 330];
        assert!(!ks_reject_homogeneity(&a, &b, 0.001));
    }

    #[test]
    fn ks_returns_false_for_small_samples() {
        let a = vec![5u32, 3];
        let b = vec![0u32, 8];
        // n < MIN_OBSERVATIONS → conservative (false)
        assert!(!ks_reject_homogeneity(&a, &b, 0.001));
    }

    #[test]
    fn build_suffix_stats_counts_correctly() {
        // Sequence "01010101" with |A|=2, max_depth=1.
        let seq = vec![0u8, 1, 0, 1, 0, 1, 0, 1];
        let stats = build_suffix_stats(&seq, 2, 1);
        // NOTE(review): a previous comment here claimed "0" is followed only
        // 3 times; it is 4 — "0" never ends the sequence (the last symbol is
        // "1").  The assertions below were already correct.
        let after_0 = &stats[&vec![0u8]];
        let after_1 = &stats[&vec![1u8]];
        // "0" appears at positions 0,2,4,6 — each followed by "1" → 4 times.
        // "1" appears at positions 1,3,5,7 — positions 1,3,5 followed by "0" (pos 7 is last) → 3 times.
        assert_eq!(after_0[1], 4, "0 → 1 four times in 01010101");
        assert_eq!(after_1[0], 3, "1 → 0 three times in 01010101");
    }
}
362}