use crate::complexity::counts_to_probs;
use crate::cssr::{run_cssr, CausalState};
/// A strategy for producing bootstrap replicates of computational-mechanics statistics
/// from a symbol sequence.
///
/// The in-crate implementations return `2 * b` values laid out as interleaved pairs
/// `[C_μ₀, h_μ₀, C_μ₁, h_μ₁, …]`, one pair per replicate. They return an empty vector when
/// `data` is empty, when `b == 0`, or when no causal states could be inferred.
pub trait BootstrapBackend {
    /// Runs `b` bootstrap replicates on `data` and returns the flattened estimates.
    fn resample_and_estimate(&self, data: &[u8], b: usize) -> Vec<f64>;
}
/// Single-threaded CPU implementation of [`BootstrapBackend`] built on `run_cssr`.
///
/// All three fields are forwarded unchanged to `run_cssr` when estimating.
#[derive(Debug, Clone)]
pub struct CpuBootstrap {
    /// History-length limit forwarded to `run_cssr`.
    pub max_depth: usize,
    /// Significance level forwarded to `run_cssr` (presumably for its state-splitting test).
    pub alpha: f64,
    /// Number of distinct symbols forwarded to `run_cssr`; input symbols are presumably
    /// expected to lie in `0..alphabet_size` — confirm against `run_cssr`.
    pub alphabet_size: usize,
}

impl CpuBootstrap {
    /// Builds a backend with explicit CSSR settings.
    #[must_use]
    pub fn new(max_depth: usize, alpha: f64, alphabet_size: usize) -> Self {
        Self {
            alphabet_size,
            alpha,
            max_depth,
        }
    }
}

impl Default for CpuBootstrap {
    /// Depth 4, `alpha = 0.001`, binary alphabet.
    fn default() -> Self {
        Self::new(4, 0.001, 2)
    }
}
impl BootstrapBackend for CpuBootstrap {
    /// Fits causal states once with `run_cssr`, then draws `b` parametric bootstrap replicates.
    ///
    /// Each replicate resamples every state's pooled next-symbol counts (`resample_states`)
    /// and evaluates `compute_ch` against a single stationary distribution `pi` estimated
    /// from `data`. The PRNG is seeded with a fixed constant, so the resampling itself is
    /// reproducible.
    ///
    /// Returns `[C_μ₀, h_μ₀, C_μ₁, h_μ₁, …]` (length `2 * b`), or an empty vector when `b == 0`,
    /// `data` is empty, or CSSR yields no states.
    ///
    /// NOTE(review): `pi` is computed once and never resampled, and `compute_ch` derives C_μ
    /// from `pi` alone — so every C_μ replicate equals the point estimate (zero bootstrap
    /// spread). Only h_μ varies across replicates.
    fn resample_and_estimate(&self, data: &[u8], b: usize) -> Vec<f64> {
        if b == 0 || data.is_empty() {
            return Vec::new();
        }

        let fitted = run_cssr(data, self.alphabet_size, self.max_depth, self.alpha);
        let states = &fitted.states;
        if states.is_empty() {
            return Vec::new();
        }

        let pi = empirical_pi(states, data, fitted.max_depth);
        let mut rng = Xorshift64::new(0xdead_beef_cafe_1234);

        let mut estimates = Vec::with_capacity(2 * b);
        estimates.extend((0..b).flat_map(|_| {
            let replicate = resample_states(states, &mut rng);
            let (c_mu, h_mu) = compute_ch(&replicate, &pi);
            [c_mu, h_mu]
        }));
        estimates
    }
}
/// Bootstrap backend exposed under the `genesis_node` feature.
///
/// NOTE(review): despite the name, no Metal/GPU code is involved here — every call is
/// forwarded to an inner [`CpuBootstrap`], so results match the CPU backend configured with
/// the same settings.
#[cfg(feature = "genesis_node")]
#[derive(Debug, Clone, Default)]
pub struct MetalBootstrap {
    inner: CpuBootstrap,
}

#[cfg(feature = "genesis_node")]
impl MetalBootstrap {
    /// Builds the backend; the arguments configure the wrapped [`CpuBootstrap`].
    #[must_use]
    pub fn new(max_depth: usize, alpha: f64, alphabet_size: usize) -> Self {
        let inner = CpuBootstrap::new(max_depth, alpha, alphabet_size);
        Self { inner }
    }
}

#[cfg(feature = "genesis_node")]
impl BootstrapBackend for MetalBootstrap {
    /// Delegates directly to the wrapped CPU backend.
    fn resample_and_estimate(&self, data: &[u8], b: usize) -> Vec<f64> {
        BootstrapBackend::resample_and_estimate(&self.inner, data, b)
    }
}
/// Estimates the stationary distribution over causal states from the observed sequence.
///
/// For every position `pos >= max_depth`, the preceding history `symbols[pos - d..pos]` is
/// looked up among the histories recorded in `states`. The longest suffix (`d = max_depth`) is
/// tried first, falling back to shorter ones. The first matching state's visit count is
/// incremented, and the result is the normalised visit frequencies.
///
/// The returned vector is aligned *positionally* with `states` (`pi[i]` belongs to
/// `states[i]`), which is how `compute_ch` pairs them. If a history is recorded in several
/// states, the later state wins.
///
/// Positions whose history matches no recorded state are skipped rather than credited to an
/// arbitrary state. Presumably these are histories CSSR discarded, e.g. transients — confirm
/// against `run_cssr`. If no position can be assigned, a uniform distribution is returned.
/// Returns an empty vector when `states` is empty.
fn empirical_pi(states: &[CausalState], symbols: &[u8], max_depth: usize) -> Vec<f64> {
    let k = states.len();
    if k == 0 {
        return Vec::new();
    }

    // Map each recorded history to the *position* of its state, not to `s.id`. This keeps
    // `visits` in bounds even if ids are not exactly `0..k`, and aligns `pi` with `states`.
    // Borrowed slice keys avoid cloning every history.
    let mut assignment: std::collections::HashMap<&[u8], usize> =
        std::collections::HashMap::new();
    for (idx, s) in states.iter().enumerate() {
        for h in &s.histories {
            assignment.insert(h.as_slice(), idx);
        }
    }

    let mut visits = vec![0u64; k];
    for pos in max_depth..symbols.len() {
        // Longest matching suffix wins.
        let matched = (1..=max_depth)
            .rev()
            .find_map(|d| assignment.get(&symbols[pos - d..pos]).copied());
        if let Some(idx) = matched {
            visits[idx] += 1;
        }
    }

    let total: u64 = visits.iter().sum();
    if total == 0 {
        return vec![1.0 / k as f64; k];
    }
    visits.iter().map(|&v| v as f64 / total as f64).collect()
}
/// Draws one parametric bootstrap replicate of the fitted causal states.
///
/// For each state, the total observation count `n = Σ pooled` is preserved. `n` symbols are
/// redrawn i.i.d. from the state's empirical next-symbol distribution
/// (`counts_to_probs(&pooled)`) using `rng`, which amounts to a multinomial resample of the
/// pooled counts. `id` and `histories` are copied unchanged, and states keep their order.
fn resample_states(states: &[CausalState], rng: &mut Xorshift64) -> Vec<CausalState> {
    let mut replicate = Vec::with_capacity(states.len());
    for state in states {
        let draws: u32 = state.pooled.iter().sum();
        let probs = counts_to_probs(&state.pooled);
        let mut counts = vec![0u32; state.pooled.len()];
        (0..draws).for_each(|_| counts[rng.sample_categorical(&probs)] += 1);
        replicate.push(CausalState {
            id: state.id,
            pooled: counts,
            histories: state.histories.clone(),
        });
    }
    replicate
}
/// Computes `(C_μ, h_μ)` in bits.
///
/// * `C_μ` is the Shannon entropy of the state distribution `pi`.
/// * `h_μ = Σ_i pi[i] · H(states[i])`, where `H(states[i])` is the entropy of the next-symbol
///   distribution `counts_to_probs(&states[i].pooled)`.
///
/// `states` and `pi` are paired by position. If their lengths differ, the surplus entries are
/// ignored for `h_μ`, but every entry of `pi` still contributes to `C_μ`. Probabilities
/// `<= 1e-300` are skipped, so `0 · log 0` is treated as 0.
fn compute_ch(states: &[CausalState], pi: &[f64]) -> (f64, f64) {
    fn entropy_bits(probs: &[f64]) -> f64 {
        probs
            .iter()
            .filter(|&&p| p > 1e-300)
            .map(|&p| -p * p.log2())
            .sum()
    }

    let statistical_complexity = entropy_bits(pi);
    let entropy_rate: f64 = states
        .iter()
        .zip(pi)
        .map(|(state, &weight)| weight * entropy_bits(&counts_to_probs(&state.pooled)))
        .sum();
    (statistical_complexity, entropy_rate)
}
/// Minimal xorshift64 PRNG (shift triple 13/7/17) used to drive bootstrap resampling.
///
/// It is fully deterministic for a given seed and is not cryptographically secure.
struct Xorshift64(u64);

impl Xorshift64 {
    /// Creates a generator from `seed`.
    ///
    /// Zero is a fixed point of the xorshift recurrence (the generator would only ever emit
    /// zeros), so a zero seed is replaced by a fixed non-zero constant.
    fn new(seed: u64) -> Self {
        Self(if seed == 0 {
            0xcafe_babe_1234_5678
        } else {
            seed
        })
    }

    /// Advances the state and returns it as the next 64-bit output.
    fn next_u64(&mut self) -> u64 {
        self.0 ^= self.0 << 13;
        self.0 ^= self.0 >> 7;
        self.0 ^= self.0 << 17;
        self.0
    }

    /// Returns a uniform sample in `[0, 1)` built from the top 53 bits of `next_u64`.
    fn next_f64(&mut self) -> f64 {
        (self.next_u64() >> 11) as f64 / (1u64 << 53) as f64
    }

    /// Samples index `i` with probability `probs[i]` by inverse-CDF lookup.
    ///
    /// Floating-point rounding can leave the cumulative sum below the drawn `u`, for example
    /// when `probs` sums to slightly less than 1. In that case the last index with a
    /// *positive* probability is returned, never a trailing zero-probability category.
    /// If no entry is positive, the last index is returned.
    ///
    /// # Panics
    /// Panics if `probs` is empty, since no valid index exists.
    fn sample_categorical(&mut self, probs: &[f64]) -> usize {
        assert!(
            !probs.is_empty(),
            "sample_categorical: empty probability vector"
        );
        let u = self.next_f64();
        let mut cum = 0.0;
        for (i, &p) in probs.iter().enumerate() {
            cum += p;
            if u < cum {
                return i;
            }
        }
        // Rounding fallback: prefer a category that can actually occur.
        probs
            .iter()
            .rposition(|&p| p > 0.0)
            .unwrap_or(probs.len() - 1)
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for the bootstrap backends and the internal PRNG.
    use super::*;
    use crate::test_utils::gen_even_process;

    #[test]
    fn cpu_bootstrap_empty_input_returns_empty() {
        let out = CpuBootstrap::default().resample_and_estimate(&[], 200);
        assert!(out.is_empty());
    }

    #[test]
    fn cpu_bootstrap_zero_b_returns_empty() {
        let data: Vec<u8> = gen_even_process(500, 1);
        let out = CpuBootstrap::default().resample_and_estimate(&data, 0);
        assert!(out.is_empty());
    }

    #[test]
    fn cpu_bootstrap_output_length_is_2b() {
        let data = gen_even_process(1000, 42);
        let out = CpuBootstrap::default().resample_and_estimate(&data, 20);
        assert_eq!(out.len(), 40, "expected 2×b=40 entries");
    }

    #[test]
    fn cpu_bootstrap_c_mu_non_negative() {
        let data = gen_even_process(2000, 7);
        let out = CpuBootstrap::default().resample_and_estimate(&data, 50);
        // Even indices hold C_μ.
        for (rep, &c_mu) in out.iter().step_by(2).enumerate() {
            assert!(c_mu >= 0.0, "C_μ[{}]={} < 0", rep, c_mu);
        }
    }

    #[test]
    fn cpu_bootstrap_h_mu_non_negative() {
        let data = gen_even_process(2000, 13);
        let out = CpuBootstrap::default().resample_and_estimate(&data, 50);
        // Odd indices hold h_μ.
        for (rep, &h_mu) in out.iter().skip(1).step_by(2).enumerate() {
            assert!(h_mu >= 0.0, "h_μ[{}]={} < 0", rep, h_mu);
        }
    }

    #[test]
    fn xorshift_output_in_unit_interval() {
        let mut rng = Xorshift64::new(99);
        let all_in_range = (0..1000)
            .map(|_| rng.next_f64())
            .all(|v| (0.0..1.0).contains(&v));
        assert!(all_in_range);
    }

    #[cfg(feature = "genesis_node")]
    #[test]
    fn metal_bootstrap_delegates_to_cpu() {
        let data = gen_even_process(500, 5);
        let cpu_out = CpuBootstrap::default().resample_and_estimate(&data, 10);
        let metal_out = MetalBootstrap::default().resample_and_estimate(&data, 10);
        // NOTE(review): only lengths are compared. Exact equality would also require
        // `run_cssr` to be deterministic, which is not established in this module.
        assert_eq!(cpu_out.len(), metal_out.len());
    }
}