use crate::CausalTensor;
use crate::CausalTensorError;
use deep_causality_num::{FromPrimitive, RealField};
/// Statistical operations on a [`CausalTensor`].
///
/// The per-method notes below describe the behaviour of the `CausalTensor<T>`
/// implementation that lives in this file; other implementors are not visible here.
pub trait CausalTensorStatsExt<T> {
/// Column-wise mean of a 2-D tensor of shape `[n, k]` (`n` observations, `k` variables).
///
/// On success the result has shape `[k]`. The in-file implementation fails when the
/// tensor is not 2-D, when either dimension is zero, or when `n` cannot be converted
/// into `T`.
fn sample_mean(&self) -> Result<CausalTensor<T>, CausalTensorError>;
/// Covariance matrix of a 2-D tensor of shape `[n, k]`, normalised by `n - 1`.
///
/// On success the result has shape `[k, k]`. In addition to the failure cases of
/// `sample_mean`, the in-file implementation rejects inputs with fewer than two rows.
fn sample_covariance(&self) -> Result<CausalTensor<T>, CausalTensorError>;
/// Logarithm of the sum of the exponentials of every element.
///
/// The in-file implementation subtracts the largest element before exponentiating and
/// returns `T::zero().ln()` for an empty tensor.
fn logsumexp(&self) -> T;
/// Element-wise log-density of a univariate Gaussian with the given `mean` and `variance`.
///
/// The result has the same shape as `self`. The in-file implementation substitutes a
/// small positive floor for any `variance` that is not greater than zero.
fn gaussian_log_density(
&self,
mean: T,
variance: T,
) -> Result<CausalTensor<T>, CausalTensorError>;
/// Variance of variable `target` given the variables listed in `parents`, reading `self`
/// as a square covariance matrix.
///
/// `ridge` is added to the diagonal of the parent sub-matrix before it is factorised.
/// The in-file implementation fails when `self` is not a square 2-D tensor or when
/// `target` or any parent index is out of range, and returns the plain variance of
/// `target` when `parents` is empty.
fn conditional_variance(
&self,
target: usize,
parents: &[usize],
ridge: T,
) -> Result<T, CausalTensorError>;
}
/// Statistics for [`CausalTensor`] over any real field that can be built from primitives.
///
/// `sample_mean` and `sample_covariance` read the tensor as an `n x k` matrix whose flat
/// storage is addressed as `row * k + col` (one observation per row).
impl<T> CausalTensorStatsExt<T> for CausalTensor<T>
where
    T: RealField + FromPrimitive,
{
    /// Returns the per-column mean of an `n x k` observation matrix as a tensor of shape `[k]`.
    ///
    /// # Errors
    ///
    /// * `DimensionMismatch` if the tensor is not 2-D.
    /// * `EmptyTensor` if either dimension is zero.
    /// * `InvalidParameter` if `n` cannot be converted into `T`.
    fn sample_mean(&self) -> Result<CausalTensor<T>, CausalTensorError> {
        let (n, k) = observations_shape(self)?;
        // `ok_or_else` keeps the error string off the success path.
        let n_t = <T as FromPrimitive>::from_usize(n).ok_or_else(|| {
            CausalTensorError::InvalidParameter(
                "observation count is not representable in the tensor's field".to_string(),
            )
        })?;

        let mut means = vec![T::zero(); k];
        // `k > 0` is guaranteed by `observations_shape`, so `chunks_exact` cannot panic.
        for row in self.as_slice().chunks_exact(k) {
            for (mean, &value) in means.iter_mut().zip(row) {
                *mean += value;
            }
        }
        for mean in means.iter_mut() {
            *mean /= n_t;
        }
        Ok(CausalTensor::from_slice(&means, &[k]))
    }

    /// Returns the `k x k` sample covariance matrix (normalised by `n - 1`) of an `n x k`
    /// observation matrix.
    ///
    /// # Errors
    ///
    /// * `DimensionMismatch` if the tensor is not 2-D.
    /// * `EmptyTensor` if either dimension is zero.
    /// * `InvalidParameter` if there are fewer than two observations, or if `n` or `n - 1`
    ///   cannot be converted into `T`.
    fn sample_covariance(&self) -> Result<CausalTensor<T>, CausalTensorError> {
        let (n, k) = observations_shape(self)?;
        if n < 2 {
            return Err(CausalTensorError::InvalidParameter(format!(
                "sample covariance needs at least 2 observations, got {n}"
            )));
        }
        let mean_tensor = self.sample_mean()?;
        let means = mean_tensor.as_slice();
        let denom = <T as FromPrimitive>::from_usize(n - 1).ok_or_else(|| {
            CausalTensorError::InvalidParameter(
                "observation count is not representable in the tensor's field".to_string(),
            )
        })?;

        let mut cov = vec![T::zero(); k * k];
        // Scratch buffer for one mean-centred observation, reused for every row.
        let mut centered = vec![T::zero(); k];
        for row in self.as_slice().chunks_exact(k) {
            for ((slot, &value), &mean) in centered.iter_mut().zip(row).zip(means) {
                *slot = value - mean;
            }
            // The matrix is symmetric, so only the upper triangle is accumulated here;
            // the lower triangle is filled in by mirroring below.
            for i in 0..k {
                let di = centered[i];
                for j in i..k {
                    cov[i * k + j] += di * centered[j];
                }
            }
        }
        for i in 0..k {
            for j in i..k {
                let value = cov[i * k + j] / denom;
                cov[i * k + j] = value;
                cov[j * k + i] = value;
            }
        }
        Ok(CausalTensor::from_slice(&cov, &[k, k]))
    }

    /// Returns `ln(sum(exp(x)))` over every element of the tensor.
    ///
    /// The largest element is subtracted before exponentiating and added back afterwards,
    /// so large inputs do not overflow in `exp`. An empty tensor yields `T::zero().ln()`.
    /// When the largest element is not finite the result is that element, unless some
    /// element is unordered with itself (NaN on IEEE-style floats), in which case that
    /// element is returned so a NaN is never hidden behind an infinity.
    fn logsumexp(&self) -> T {
        let xs = self.as_slice();
        if xs.is_empty() {
            return T::zero().ln();
        }
        let max = xs
            .iter()
            .copied()
            .fold(xs[0], |acc, x| if x > acc { x } else { acc });
        if !max.is_finite() {
            // A NaN that is not the first element never wins the `>` comparison above, so
            // `max` can be an infinity even though a NaN is present. `partial_cmp` of a
            // value with itself is `None` exactly when the value is unordered (NaN).
            return xs
                .iter()
                .copied()
                .find(|x| x.partial_cmp(x).is_none())
                .unwrap_or(max);
        }
        let sum = xs
            .iter()
            .copied()
            .fold(T::zero(), |acc, x| acc + (x - max).exp());
        max + sum.ln()
    }

    /// Evaluates the Gaussian log-density `-0.5 * (ln(2 * pi * var) + (x - mean)^2 / var)`
    /// for every element `x`, returning a tensor with the same shape as `self`.
    ///
    /// A `variance` that does not compare greater than zero is replaced by
    /// `variance_floor()` so the division and logarithm stay defined. An empty tensor is
    /// returned as an empty tensor of the same shape.
    ///
    /// # Errors
    ///
    /// The current implementation always returns `Ok`.
    ///
    /// # Panics
    ///
    /// Panics if `0.5` or `2.0` cannot be converted into `T`.
    fn gaussian_log_density(
        &self,
        mean: T,
        variance: T,
    ) -> Result<CausalTensor<T>, CausalTensorError> {
        if self.is_empty() {
            return Ok(CausalTensor::from_slice(&[], self.shape()));
        }
        let var = if variance > T::zero() {
            variance
        } else {
            variance_floor::<T>()
        };
        let half =
            <T as FromPrimitive>::from_f64(0.5).expect("0.5 is representable in every RealField");
        let two =
            <T as FromPrimitive>::from_f64(2.0).expect("2.0 is representable in every RealField");
        // The normalisation term does not depend on the element, so compute it once.
        let log_two_pi_var = (two * T::pi() * var).ln();
        let new_data: Vec<T> = self
            .as_slice()
            .iter()
            .map(|&x| {
                let diff = x - mean;
                -half * (log_two_pi_var + (diff * diff) / var)
            })
            .collect();
        Ok(CausalTensor::from_slice(&new_data, self.shape()))
    }

    /// Reads `self` as an `m x m` covariance matrix and returns the variance of variable
    /// `target` conditioned on the variables in `parents`:
    /// `S_yy - S_yp * (S_pp + ridge * I)^-1 * S_py`.
    ///
    /// The parent block is factorised with `cholesky_in_place`, which floors non-positive
    /// pivots instead of reporting them, so an ill-conditioned parent block does not
    /// produce an error. With no parents the marginal variance `S_yy` is returned.
    ///
    /// # Errors
    ///
    /// * `DimensionMismatch` if the tensor is not a square 2-D matrix.
    /// * `IndexOutOfBounds` if `target` or any entry of `parents` is `>= m`, or if an
    ///   element lookup fails.
    fn conditional_variance(
        &self,
        target: usize,
        parents: &[usize],
        ridge: T,
    ) -> Result<T, CausalTensorError> {
        let shape = self.shape();
        if shape.len() != 2 || shape[0] != shape[1] {
            return Err(CausalTensorError::DimensionMismatch);
        }
        let m = shape[0];
        if target >= m || parents.iter().any(|&p| p >= m) {
            return Err(CausalTensorError::IndexOutOfBounds);
        }
        let entry = |i: usize, j: usize| -> Result<T, CausalTensorError> {
            self.get(&[i, j])
                .copied()
                .ok_or(CausalTensorError::IndexOutOfBounds)
        };

        let sigma_yy = entry(target, target)?;
        let k = parents.len();
        if k == 0 {
            return Ok(sigma_yy);
        }

        // Cross-covariances between the target and each parent.
        let sigma_yp = parents
            .iter()
            .map(|&p| entry(target, p))
            .collect::<Result<Vec<T>, _>>()?;

        // Parent-by-parent block, with `ridge` added on its diagonal.
        let mut sigma_pp = vec![T::zero(); k * k];
        for (i, &pi) in parents.iter().enumerate() {
            for (j, &pj) in parents.iter().enumerate() {
                let mut v = entry(pi, pj)?;
                if i == j {
                    v += ridge;
                }
                sigma_pp[i * k + j] = v;
            }
        }

        // Solve (S_pp + ridge * I) z = S_py through the Cholesky factor.
        cholesky_in_place(&mut sigma_pp, k);
        let mut z = sigma_yp.clone();
        cholesky_solve_in_place(&sigma_pp, &mut z, k);

        let mut reduction = T::zero();
        for (&cov, &weight) in sigma_yp.iter().zip(&z) {
            reduction += cov * weight;
        }
        Ok(sigma_yy - reduction)
    }
}
/// Smallest variance substituted for a non-positive one: `1e-12` converted into `T`.
///
/// # Panics
///
/// Panics if `1e-12` cannot be converted into `T`.
fn variance_floor<T>() -> T
where
    T: RealField + FromPrimitive,
{
    const FLOOR: f64 = 1e-12;
    let converted: Option<T> = <T as FromPrimitive>::from_f64(FLOOR);
    converted.expect("1e-12 is representable in every RealField")
}
/// Overwrites the lower triangle (diagonal included) of the row-major `k x k` matrix `a`
/// with its Cholesky factor `L`, column by column.
///
/// Entries above the diagonal are neither read nor written. A pivot that is not greater
/// than zero is replaced by `T::epsilon()` before the square root is taken, so the routine
/// always runs to completion instead of reporting a non-positive-definite input.
fn cholesky_in_place<T>(a: &mut [T], k: usize)
where
    T: RealField,
{
    for col in 0..k {
        let col_row = col * k;

        // Diagonal entry: subtract the squares of the factor entries already in this row.
        let residual = a[col_row..col_row + col]
            .iter()
            .fold(a[col_row + col], |mut acc, &l| {
                acc -= l * l;
                acc
            });
        let pivot = if residual > T::zero() {
            residual
        } else {
            T::epsilon()
        };
        let diagonal = pivot.sqrt();
        a[col_row + col] = diagonal;

        // Entries below the diagonal in this column.
        for row in (col + 1)..k {
            let row_start = row * k;
            let numerator = a[row_start..row_start + col]
                .iter()
                .zip(&a[col_row..col_row + col])
                .fold(a[row_start + col], |mut acc, (&x, &y)| {
                    acc -= x * y;
                    acc
                });
            a[row_start + col] = numerator / diagonal;
        }
    }
}
/// Solves `L * L^T * x = b` in place, where `l` is a row-major `k x k` buffer whose lower
/// triangle holds the factor `L`. On return `b` holds `x`.
///
/// Only the lower triangle of `l` (diagonal included) is read.
fn cholesky_solve_in_place<T>(l: &[T], b: &mut [T], k: usize)
where
    T: RealField,
{
    // Forward substitution: L * y = b, with y overwriting b.
    for i in 0..k {
        let row = i * k;
        let remainder = (0..i).fold(b[i], |mut acc, p| {
            acc -= l[row + p] * b[p];
            acc
        });
        b[i] = remainder / l[row + i];
    }

    // Back substitution: L^T * x = y. Column `i` of L is read as row `i` of L^T.
    for i in (0..k).rev() {
        let remainder = ((i + 1)..k).fold(b[i], |mut acc, p| {
            acc -= l[p * k + i] * b[p];
            acc
        });
        b[i] = remainder / l[i * k + i];
    }
}
/// Checks that `tensor` is a non-empty 2-D tensor and returns its dimensions as
/// `(rows, columns)`.
///
/// # Errors
///
/// * `DimensionMismatch` if the tensor does not have exactly two dimensions.
/// * `EmptyTensor` if either dimension is zero.
fn observations_shape<T>(tensor: &CausalTensor<T>) -> Result<(usize, usize), CausalTensorError> {
    let dims: &[usize] = tensor.shape();
    match dims {
        [0, _] | [_, 0] => Err(CausalTensorError::EmptyTensor),
        [n, k] => Ok((*n, *k)),
        _ => Err(CausalTensorError::DimensionMismatch),
    }
}