use crate::error::ConsciousnessError;
use crate::simd::marginal_distribution;
use crate::types::{PidResult, TransitionMatrix};
use std::time::Instant;
/// Computes a Williams–Beer partial information decomposition (PID) of the information that
/// the `sources` (bit-index sets of the current state) carry about the `target` (bit-index
/// set of the next state).
///
/// All distributions use a uniform prior over the `tpm.n` current global states, and bit `i`
/// of a global state index is node `i`. The returned [`PidResult`] holds, in nats:
/// - `total_mi`: `I(S_1, …, S_k; T)`, computed on the concatenation of all source indices;
/// - `redundancy`: the Williams–Beer `I_min` across all sources;
/// - `unique[i]`: `I(S_i; T) - redundancy`, clamped at 0;
/// - `synergy`: `total_mi - redundancy - Σ unique`, clamped at 0.
///
/// NOTE: the synergy identity is the two-source one. With more than two sources the residual
/// also absorbs the higher-order lattice terms rather than resolving them individually.
///
/// # Errors
/// - `ValidationError::EmptySystem` if `tpm.n < 2`.
/// - `ValidationError::DimensionMismatch` if `sources` is empty. It is also returned if any
///   source or target index is not a valid bit position for a `usize`, or if a variable
///   (including the concatenation of all sources) spans so many indices that its
///   `1 << len` table size would overflow.
pub fn compute_pid(
    tpm: &TransitionMatrix,
    sources: &[Vec<usize>],
    target: &[usize],
) -> Result<PidResult, ConsciousnessError> {
    let n = tpm.n;
    if n < 2 {
        return Err(crate::error::ValidationError::EmptySystem.into());
    }
    if sources.is_empty() {
        return Err(crate::error::ValidationError::DimensionMismatch(
            "need at least one source".into(),
        )
        .into());
    }
    // Reject index sets that would overflow a shift in the substate/table helpers. Without
    // this check they panic in debug builds and silently mask the shift in release builds.
    for source in sources {
        check_indices("source", source)?;
    }
    check_indices("target", target)?;
    let all_sources: Vec<usize> = sources.iter().flatten().copied().collect();
    check_indices("combined sources", &all_sources)?;

    let start = Instant::now();
    let marginal = marginal_distribution(tpm.as_slice(), n);

    let source_mis: Vec<f64> = sources
        .iter()
        .map(|source| source_target_mi(tpm, n, source, target, &marginal))
        .collect();
    let total_mi = source_target_mi(tpm, n, &all_sources, target, &marginal);
    let redundancy = williams_beer_imin(tpm, n, sources, target, &marginal);

    // For a row-stochastic TPM, I_min cannot exceed any single source's MI, so this clamp
    // only absorbs floating-point round-off.
    let unique: Vec<f64> = source_mis
        .iter()
        .map(|&mi| (mi - redundancy).max(0.0))
        .collect();
    let unique_sum: f64 = unique.iter().sum();
    let synergy = (total_mi - redundancy - unique_sum).max(0.0);

    Ok(PidResult {
        redundancy,
        unique,
        synergy,
        total_mi,
        num_sources: sources.len(),
        elapsed: start.elapsed(),
    })
}

/// Ensures `indices` can be used as shift amounts by `extract_substate` and as the exponent of
/// a `1 << len` table size without overflowing a `usize`.
fn check_indices(role: &str, indices: &[usize]) -> Result<(), ConsciousnessError> {
    let word_bits = usize::BITS as usize;
    if indices.len() >= word_bits {
        return Err(crate::error::ValidationError::DimensionMismatch(
            format!(
                "{} spans {} indices; at most {} are supported",
                role,
                indices.len(),
                word_bits - 1
            )
            .into(),
        )
        .into());
    }
    if let Some(&bad) = indices.iter().find(|&&idx| idx >= word_bits) {
        return Err(crate::error::ValidationError::DimensionMismatch(
            format!(
                "{} index {} is not a valid bit position (must be < {})",
                role, bad, word_bits
            )
            .into(),
        )
        .into());
    }
    Ok(())
}
/// Williams–Beer redundancy `I_min(T; S_1, …, S_k) = Σ_t p(t) · min_i I(S_i; T = t)`, in nats.
///
/// Target states with `p(t) < 1e-15` are skipped. The target marginal and each source
/// marginal are computed once and shared by every `specific_information_cached` call. The
/// result is clamped at 0, and it is 0 when `sources` is empty. `_marginal` is accepted so
/// the call site stays unchanged, but it is not used.
fn williams_beer_imin(
    tpm: &TransitionMatrix,
    n: usize,
    sources: &[Vec<usize>],
    target: &[usize],
    _marginal: &[f64],
) -> f64 {
    let target_marginal = compute_target_marginal(tpm, n, target);
    let source_marginals: Vec<Vec<f64>> = sources
        .iter()
        .map(|s| compute_source_marginal(tpm, n, s))
        .collect();

    let mut imin = 0.0f64;
    for (t_state, &p_t) in target_marginal.iter().enumerate() {
        if p_t < 1e-15 {
            continue;
        }
        // Specific information of the least informative source about this target outcome.
        // It is `None` only when there are no sources.
        let min_spec = sources
            .iter()
            .zip(&source_marginals)
            .map(|(source, source_marginal)| {
                specific_information_cached(
                    tpm,
                    n,
                    source,
                    target,
                    t_state,
                    &target_marginal,
                    source_marginal,
                )
            })
            .reduce(f64::min);
        if let Some(min_spec) = min_spec {
            imin += p_t * min_spec;
        }
    }
    imin.max(0.0)
}
/// Specific information `I(S; T = t)` of one source about a single target state, in nats.
///
/// The value is the KL divergence `D(p(s | t) || p(s))`. Here `p(s, t)` is accumulated as
/// `(1/n) Σ_{x: s(x)=s} Σ_{y: t(y)=t} TPM[x][y]` and then normalised over `s` to give
/// `p(s | t)`. `p(s)` is the supplied `source_marginal`.
///
/// Both marginals are passed in so callers iterating over many target states compute them
/// once. The function returns 0 when `target_marginal[target_state] < 1e-15`. Only terms
/// where both probabilities exceed `1e-15` contribute, and the result is clamped at 0.
fn specific_information_cached(
    tpm: &TransitionMatrix,
    n: usize,
    source: &[usize],
    target: &[usize],
    target_state: usize,
    target_marginal: &[f64],
    source_marginal: &[f64],
) -> f64 {
    let source_size = source_marginal.len();
    let p_t = target_marginal[target_state];
    if p_t < 1e-15 {
        return 0.0;
    }

    // Which future states realise `target_state` does not depend on the current state.
    // Resolve them once instead of re-extracting target bits inside the n×n loop.
    let matching_futures: Vec<usize> = (0..n)
        .filter(|&future| extract_substate(future, target) == target_state)
        .collect();

    let inv_n = 1.0 / n as f64;
    let mut p_s_given_t = vec![0.0f64; source_size];
    for global_state in 0..n {
        let s_state = extract_substate(global_state, source);
        if s_state < source_size {
            // Summed in ascending future order, matching a full scan over all futures.
            let p_target_given_global = matching_futures
                .iter()
                .fold(0.0f64, |acc, &future| acc + tpm.get(global_state, future));
            p_s_given_t[s_state] += inv_n * p_target_given_global;
        }
    }

    let sum: f64 = p_s_given_t.iter().sum();
    if sum > 1e-15 {
        let inv = 1.0 / sum;
        for p in &mut p_s_given_t {
            *p *= inv;
        }
    }

    let dkl = p_s_given_t
        .iter()
        .zip(source_marginal)
        .fold(0.0f64, |acc, (&p, &q)| {
            if p > 1e-15 && q > 1e-15 {
                acc + p * (p / q).ln()
            } else {
                acc
            }
        });
    dkl.max(0.0)
}
fn specific_information(
tpm: &TransitionMatrix,
n: usize,
source: &[usize],
target: &[usize],
target_state: usize,
target_marginal: &[f64],
) -> f64 {
let source_marginal = compute_source_marginal(tpm, n, source);
let source_size = source_marginal.len();
let p_t = target_marginal[target_state];
if p_t < 1e-15 {
return 0.0;
}
let mut p_s_given_t = vec![0.0f64; source_size];
let inv_n = 1.0 / n as f64;
for global_state in 0..n {
let s_state = extract_substate(global_state, source);
let t_state_actual = extract_substate(global_state, target);
if s_state < source_size {
let mut p_target_given_global = 0.0;
for future in 0..n {
if extract_substate(future, target) == target_state {
p_target_given_global += tpm.get(global_state, future);
}
}
p_s_given_t[s_state] += inv_n * p_target_given_global;
}
}
let sum: f64 = p_s_given_t.iter().sum();
if sum > 1e-15 {
let inv = 1.0 / sum;
for p in &mut p_s_given_t {
*p *= inv;
}
}
let mut dkl = 0.0f64;
for i in 0..source_size {
let p = p_s_given_t[i];
let q = source_marginal[i];
if p > 1e-15 && q > 1e-15 {
dkl += p * (p / q).ln();
}
}
dkl.max(0.0)
}
/// Mutual information `I(S; T)` in nats. `S` is the `source` substate of the current state
/// and `T` is the `target` substate of the next state, under a uniform prior over current
/// states.
///
/// Marginals come from `compute_source_marginal` / `compute_target_marginal` and the joint
/// from `compute_joint_distribution`. A cell contributes only when its joint, source and
/// target probabilities all exceed `1e-15`. The sum is clamped at 0. `_marginal` is
/// accepted for call-site compatibility and ignored.
fn source_target_mi(
    tpm: &TransitionMatrix,
    n: usize,
    source: &[usize],
    target: &[usize],
    _marginal: &[f64],
) -> f64 {
    const EPS: f64 = 1e-15;

    let p_source = compute_source_marginal(tpm, n, source);
    let p_target = compute_target_marginal(tpm, n, target);
    let p_joint = compute_joint_distribution(tpm, n, source, target);

    // The joint is laid out row-major: one row of `p_target.len()` cells per source state.
    let row_width = p_target.len();
    let mut information = 0.0f64;
    for (&ps, joint_row) in p_source.iter().zip(p_joint.chunks_exact(row_width)) {
        for (&pst, &pt) in joint_row.iter().zip(&p_target) {
            let informative = pst > EPS && ps > EPS && pt > EPS;
            if informative {
                information += pst * (pst / (ps * pt)).ln();
            }
        }
    }
    information.max(0.0)
}
/// Marginal distribution of the `source` substate under a uniform prior over the `n` current
/// global states: `p(s) = |{x in 0..n : s(x) = s}| / n`.
///
/// The result has `1 << source.len()` entries. It depends only on which bits each state
/// index carries, not on transition probabilities. For that reason the matrix parameter is
/// unused (hence `_tpm`); it is kept so this helper shares a signature with the other
/// distribution builders.
fn compute_source_marginal(_tpm: &TransitionMatrix, n: usize, source: &[usize]) -> Vec<f64> {
    let size = 1usize << source.len();
    let inv_n = 1.0 / n as f64;
    let mut dist = vec![0.0f64; size];
    for sub in (0..n).map(|state| extract_substate(state, source)) {
        // `extract_substate` packs `source.len()` bits, so `sub < size` always holds here.
        // `get_mut` keeps the write panic-free regardless.
        if let Some(slot) = dist.get_mut(sub) {
            *slot += inv_n;
        }
    }
    dist
}
/// Marginal distribution of the `target` substate of the next state under a uniform prior
/// over current states: `p(t) = (1/n) Σ_x Σ_{y: t(y)=t} TPM[x][y]`.
///
/// The result has `1 << target.len()` entries.
fn compute_target_marginal(tpm: &TransitionMatrix, n: usize, target: &[usize]) -> Vec<f64> {
    let size = 1usize << target.len();
    let inv_n = 1.0 / n as f64;
    // The target substate of each future state is loop-invariant. Computing it once removes
    // the bit extraction from the n×n loop without changing the accumulation order.
    let future_subs: Vec<usize> = (0..n)
        .map(|future| extract_substate(future, target))
        .collect();
    let mut dist = vec![0.0f64; size];
    for state in 0..n {
        for (future, &t_sub) in future_subs.iter().enumerate() {
            if t_sub < size {
                dist[t_sub] += inv_n * tpm.get(state, future);
            }
        }
    }
    dist
}
/// Joint distribution of the current-state `source` substate and the next-state `target`
/// substate under a uniform prior over current states.
///
/// The result is a row-major table with `1 << source.len()` rows and `1 << target.len()`
/// columns. Each entry is `joint[s * target_size + t] = (1/n) Σ_{x: s(x)=s} Σ_{y: t(y)=t}
/// TPM[x][y]`.
fn compute_joint_distribution(
    tpm: &TransitionMatrix,
    n: usize,
    source: &[usize],
    target: &[usize],
) -> Vec<f64> {
    let source_size = 1usize << source.len();
    let target_size = 1usize << target.len();
    let inv_n = 1.0 / n as f64;
    // The target substate of each future state is loop-invariant. Computing it once removes
    // the bit extraction from the n×n loop without changing the accumulation order.
    let future_subs: Vec<usize> = (0..n)
        .map(|future| extract_substate(future, target))
        .collect();
    let mut joint = vec![0.0f64; source_size * target_size];
    for state in 0..n {
        let s_sub = extract_substate(state, source);
        if s_sub >= source_size {
            continue;
        }
        let row_start = s_sub * target_size;
        for (future, &t_sub) in future_subs.iter().enumerate() {
            if t_sub < target_size {
                joint[row_start + t_sub] += inv_n * tpm.get(state, future);
            }
        }
    }
    joint
}
/// Packs the bits of `global_state` at the positions listed in `indices` into a compact
/// substate: the bit at position `indices[k]` becomes bit `k` of the result.
///
/// An empty `indices` yields 0. Indices are used directly as shift amounts, so a position
/// of `usize::BITS` or more overflows the shift (a panic in debug builds).
#[inline]
fn extract_substate(global_state: usize, indices: &[usize]) -> usize {
    indices
        .iter()
        .enumerate()
        .fold(0usize, |packed, (slot, &node)| {
            let bit = (global_state >> node) & 1;
            packed | (bit << slot)
        })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Two-node system (global states 0..4). States 0–2 each move to states 0, 1 and 2 with
    /// probabilities 0.5, 0.25 and 0.25; state 3 always moves to state 3.
    fn and_gate_tpm() -> TransitionMatrix {
        #[rustfmt::skip]
        let data = vec![
            0.5, 0.25, 0.25, 0.0,
            0.5, 0.25, 0.25, 0.0,
            0.5, 0.25, 0.25, 0.0,
            0.0, 0.0, 0.0, 1.0,
        ];
        TransitionMatrix::new(4, data)
    }

    /// Two-node system whose next-state bit 0 is the XOR of current bits 0 and 1. Next-state
    /// bit 1 is always 0. Neither input bit alone is informative about the output bit.
    fn xor_gate_tpm() -> TransitionMatrix {
        #[rustfmt::skip]
        let data = vec![
            1.0, 0.0, 0.0, 0.0,
            0.0, 1.0, 0.0, 0.0,
            0.0, 1.0, 0.0, 0.0,
            1.0, 0.0, 0.0, 0.0,
        ];
        TransitionMatrix::new(4, data)
    }

    #[test]
    fn pid_two_sources() {
        let tpm = and_gate_tpm();
        // Bits 2 and 3 lie outside this two-node system, so the second source is constant.
        // It therefore carries no information and forces zero redundancy.
        let sources = vec![vec![0, 1], vec![2, 3]];
        let target = vec![0, 1];
        let result = compute_pid(&tpm, &sources, &target).unwrap();
        assert!(result.total_mi >= 0.0);
        assert!(result.redundancy >= 0.0);
        assert!(result.synergy >= 0.0);
        assert_eq!(result.num_sources, 2);
        assert_eq!(result.unique.len(), 2);
        assert!(
            result.redundancy.abs() < 1e-12,
            "constant source should force zero redundancy, got {}",
            result.redundancy
        );
        assert!(
            result.unique[1].abs() < 1e-12,
            "constant source should have no unique information, got {}",
            result.unique[1]
        );
    }

    #[test]
    fn pid_decomposition_sums() {
        let tpm = and_gate_tpm();
        let sources = vec![vec![0], vec![1]];
        let target = vec![0, 1];
        let result = compute_pid(&tpm, &sources, &target).unwrap();
        let sum = result.redundancy + result.unique.iter().sum::<f64>() + result.synergy;
        assert!(
            (sum - result.total_mi).abs() < 1e-6,
            "PID sum {} should equal total MI {}",
            sum,
            result.total_mi
        );
    }

    #[test]
    fn pid_xor_is_purely_synergistic() {
        let tpm = xor_gate_tpm();
        let sources = vec![vec![0], vec![1]];
        let target = vec![0];
        let result = compute_pid(&tpm, &sources, &target).unwrap();
        let ln2 = std::f64::consts::LN_2;
        assert!(
            (result.total_mi - ln2).abs() < 1e-9,
            "XOR carries one bit jointly, got total MI {}",
            result.total_mi
        );
        assert!(result.redundancy.abs() < 1e-9, "redundancy {}", result.redundancy);
        assert!(
            result.unique.iter().all(|u| u.abs() < 1e-9),
            "unique {:?}",
            result.unique
        );
        assert!(
            (result.synergy - ln2).abs() < 1e-9,
            "XOR information should be all synergy, got {}",
            result.synergy
        );
    }

    #[test]
    fn pid_rejects_empty() {
        let tpm = and_gate_tpm();
        assert!(compute_pid(&tpm, &[], &[0]).is_err());
    }
}