use crate::DType;
use crate::stats::helpers::extract_scalar;
use crate::stats::{TensorDescriptiveStats, validate_stats_dtype};
use numr::error::{Error, Result};
use numr::ops::{StatisticalOps, TensorOps};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
/// Computes summary statistics over every element of `x` (all dimensions are reduced).
///
/// The returned [`TensorDescriptiveStats`] holds:
/// - `nobs`: the number of elements,
/// - `min`, `max`, `mean`: the corresponding client reductions,
/// - `variance`, `std`: `client.var` / `client.std` called with a final argument of `1`
///   (presumably the delta degrees of freedom — NOTE(review): confirm against numr),
/// - `skewness`: the moment ratio `(m3/n) / (m2/n)^1.5`, or `0.0` when `n <= 2` or the sum
///   of squared deviations is not positive,
/// - `kurtosis`: the excess moment ratio `(m4/n) / (m2/n)^2 - 3`, or `0.0` when `n <= 3`
///   or the sum of squared deviations is not positive,
///
/// where `mK` is the sum of the K-th powers of the deviations from the mean.
///
/// # Errors
/// Returns [`Error::InvalidArgument`] if `x` has no elements, and propagates errors from
/// `validate_stats_dtype`, `contiguous`, the client operations and `extract_scalar`.
pub fn describe_impl<R, C>(client: &C, x: &Tensor<R>) -> Result<TensorDescriptiveStats<R>>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + RuntimeClient<R>,
{
    validate_stats_dtype(x.dtype())?;
    if x.numel() == 0 {
        return Err(Error::InvalidArgument {
            arg: "x",
            reason: "cannot compute statistics on empty tensor".to_string(),
        });
    }
    let x_contig = x.contiguous()?;
    let n = x_contig.numel();
    let n_f = n as f64;
    let all_dims: Vec<usize> = (0..x_contig.ndim()).collect();

    let mean_tensor = client.mean(&x_contig, &all_dims, false)?;
    let var_tensor = client.var(&x_contig, &all_dims, false, 1)?;
    let min_tensor = client.min(&x_contig, &all_dims, false)?;
    let max_tensor = client.max(&x_contig, &all_dims, false)?;
    let std_tensor = client.std(&x_contig, &all_dims, false, 1)?;

    // Central moments: subtract the (scalar) mean from every element, then sum powers.
    let mean_val = extract_scalar(&mean_tensor)?;
    let mean_broadcast =
        Tensor::<R>::full_scalar(x_contig.shape(), x.dtype(), mean_val, client.device());
    let centered = client.sub(&x_contig, &mean_broadcast)?;
    let centered_sq = client.mul(&centered, &centered)?;
    let centered_cu = client.mul(&centered_sq, &centered)?;
    let centered_qu = client.mul(&centered_sq, &centered_sq)?;
    let m2 = extract_scalar(&client.sum(&centered_sq, &all_dims, false)?)?;
    let m3 = extract_scalar(&client.sum(&centered_cu, &all_dims, false)?)?;
    let m4 = extract_scalar(&client.sum(&centered_qu, &all_dims, false)?)?;

    // Population second moment shared by both shape statistics. `m2 > 0.0` guards
    // against dividing by zero (constant data) and also rejects NaN.
    let m2_norm = m2 / n_f;
    let skewness_val = if n > 2 && m2 > 0.0 {
        (m3 / n_f) / m2_norm.powf(1.5)
    } else {
        0.0
    };
    let kurtosis_val = if n > 3 && m2 > 0.0 {
        (m4 / n_f) / (m2_norm * m2_norm) - 3.0
    } else {
        0.0
    };
    let skewness_tensor = Tensor::<R>::full_scalar(&[], x.dtype(), skewness_val, client.device());
    let kurtosis_tensor = Tensor::<R>::full_scalar(&[], x.dtype(), kurtosis_val, client.device());

    Ok(TensorDescriptiveStats {
        nobs: n,
        min: min_tensor,
        max: max_tensor,
        mean: mean_tensor,
        variance: var_tensor,
        std: std_tensor,
        skewness: skewness_tensor,
        kurtosis: kurtosis_tensor,
    })
}
pub fn percentile_impl<R, C>(client: &C, x: &Tensor<R>, p: f64) -> Result<Tensor<R>>
where
R: Runtime<DType = DType>,
C: TensorOps<R> + StatisticalOps<R> + RuntimeClient<R>,
{
validate_stats_dtype(x.dtype())?;
if !(0.0..=100.0).contains(&p) {
return Err(Error::InvalidArgument {
arg: "p",
reason: format!("percentile must be in [0, 100], got {}", p),
});
}
client.percentile(x, p, None, false)
}
pub fn iqr_impl<R, C>(client: &C, x: &Tensor<R>) -> Result<Tensor<R>>
where
R: Runtime<DType = DType>,
C: TensorOps<R> + RuntimeClient<R>,
{
let q1 = percentile_impl(client, x, 25.0)?;
let q3 = percentile_impl(client, x, 75.0)?;
client.sub(&q3, &q1)
}
/// Skewness of all elements of `x`, returned as a 0-d tensor with `x`'s dtype.
///
/// Uses the moment ratio `(m3/n) / (m2/n)^1.5`, where `mK` is the sum of the K-th powers
/// of the deviations from the mean. No small-sample bias correction is applied. Returns
/// `0.0` when the sum of squared deviations is not positive (e.g. constant data).
///
/// # Errors
/// Returns [`Error::InvalidArgument`] when `x` has fewer than 3 elements. Propagates errors
/// from `validate_stats_dtype`, `contiguous`, the client operations and `extract_scalar`.
pub fn skewness_impl<R, C>(client: &C, x: &Tensor<R>) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + RuntimeClient<R>,
{
    validate_stats_dtype(x.dtype())?;
    // Reject undersized inputs before `contiguous()`, which may have to copy the data.
    let n = x.numel();
    if n < 3 {
        return Err(Error::InvalidArgument {
            arg: "x",
            reason: "skewness requires at least 3 samples".to_string(),
        });
    }
    let x_contig = x.contiguous()?;
    let all_dims: Vec<usize> = (0..x_contig.ndim()).collect();
    let mean_val = extract_scalar(&client.mean(&x_contig, &all_dims, false)?)?;
    let mean_broadcast =
        Tensor::<R>::full_scalar(x_contig.shape(), x.dtype(), mean_val, client.device());
    let centered = client.sub(&x_contig, &mean_broadcast)?;
    let centered_sq = client.mul(&centered, &centered)?;
    let centered_cu = client.mul(&centered_sq, &centered)?;
    let m2 = extract_scalar(&client.sum(&centered_sq, &all_dims, false)?)?;
    let m3 = extract_scalar(&client.sum(&centered_cu, &all_dims, false)?)?;
    let n_f = n as f64;
    // `m2 > 0.0` avoids 0/0 for constant data and also rejects NaN.
    let skew = if m2 > 0.0 {
        let m2_norm = m2 / n_f;
        (m3 / n_f) / m2_norm.powf(1.5)
    } else {
        0.0
    };
    Ok(Tensor::<R>::full_scalar(&[], x.dtype(), skew, client.device()))
}
/// Excess kurtosis of all elements of `x`, returned as a 0-d tensor with `x`'s dtype.
///
/// Uses `(m4/n) / (m2/n)^2 - 3`, where `mK` is the sum of the K-th powers of the
/// deviations from the mean. No small-sample bias correction is applied. Returns `0.0`
/// when the sum of squared deviations is not positive (e.g. constant data).
///
/// # Errors
/// Returns [`Error::InvalidArgument`] when `x` has fewer than 4 elements. Propagates errors
/// from `validate_stats_dtype`, `contiguous`, the client operations and `extract_scalar`.
pub fn kurtosis_impl<R, C>(client: &C, x: &Tensor<R>) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + RuntimeClient<R>,
{
    validate_stats_dtype(x.dtype())?;
    // Reject undersized inputs before `contiguous()`, which may have to copy the data.
    let n = x.numel();
    if n < 4 {
        return Err(Error::InvalidArgument {
            arg: "x",
            reason: "kurtosis requires at least 4 samples".to_string(),
        });
    }
    let x_contig = x.contiguous()?;
    let all_dims: Vec<usize> = (0..x_contig.ndim()).collect();
    let mean_val = extract_scalar(&client.mean(&x_contig, &all_dims, false)?)?;
    let mean_broadcast =
        Tensor::<R>::full_scalar(x_contig.shape(), x.dtype(), mean_val, client.device());
    let centered = client.sub(&x_contig, &mean_broadcast)?;
    let centered_sq = client.mul(&centered, &centered)?;
    let centered_qu = client.mul(&centered_sq, &centered_sq)?;
    let m2 = extract_scalar(&client.sum(&centered_sq, &all_dims, false)?)?;
    let m4 = extract_scalar(&client.sum(&centered_qu, &all_dims, false)?)?;
    let n_f = n as f64;
    // `m2 > 0.0` avoids 0/0 for constant data and also rejects NaN.
    let kurt = if m2 > 0.0 {
        let m2_norm = m2 / n_f;
        (m4 / n_f) / (m2_norm * m2_norm) - 3.0
    } else {
        0.0
    };
    Ok(Tensor::<R>::full_scalar(&[], x.dtype(), kurt, client.device()))
}
/// Standardizes every element of `x` as `(x - mean) / std`. The mean and standard
/// deviation are scalars computed over all elements.
///
/// `std` comes from `client.std` with a final argument of `1` (presumably the delta
/// degrees of freedom — NOTE(review): confirm against numr). If that value is exactly
/// `0.0`, the function returns a zero tensor with `x`'s shape and dtype instead of dividing
/// by zero.
///
/// NOTE(review): a NaN standard deviation is not caught by the zero check and propagates
/// into the output. This could happen, for example, with a single element if the argument
/// really is ddof = 1.
///
/// # Errors
/// Propagates errors from `validate_stats_dtype`, `contiguous`, the client operations and
/// `extract_scalar`.
pub fn zscore_impl<R, C>(client: &C, x: &Tensor<R>) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + RuntimeClient<R>,
{
    validate_stats_dtype(x.dtype())?;
    let x_contig = x.contiguous()?;
    let all_dims: Vec<usize> = (0..x_contig.ndim()).collect();
    let mean_val = extract_scalar(&client.mean(&x_contig, &all_dims, false)?)?;
    let std_val = extract_scalar(&client.std(&x_contig, &all_dims, false, 1)?)?;
    if std_val == 0.0 {
        return Ok(Tensor::<R>::zeros(
            x_contig.shape(),
            x.dtype(),
            client.device(),
        ));
    }
    let mean_broadcast =
        Tensor::<R>::full_scalar(x_contig.shape(), x.dtype(), mean_val, client.device());
    let centered = client.sub(&x_contig, &mean_broadcast)?;
    let std_broadcast =
        Tensor::<R>::full_scalar(x_contig.shape(), x.dtype(), std_val, client.device());
    client.div(&centered, &std_broadcast)
}
/// Standard error of the mean over all elements of `x`, returned as a 0-d tensor with
/// `x`'s dtype.
///
/// Computed as `std / sqrt(numel)`, where `std` is `client.std` reduced over every
/// dimension with a final argument of `1` (presumably the delta degrees of freedom —
/// NOTE(review): confirm against numr).
///
/// NOTE(review): there is no explicit guard for very small or empty inputs. The result in
/// those cases depends on what `client.std` returns.
///
/// # Errors
/// Propagates errors from `validate_stats_dtype`, `contiguous`, `client.std` and
/// `extract_scalar`.
pub fn sem_impl<R, C>(client: &C, x: &Tensor<R>) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + RuntimeClient<R>,
{
    validate_stats_dtype(x.dtype())?;
    let dense = x.contiguous()?;
    let reduce_axes = (0..dense.ndim()).collect::<Vec<usize>>();
    let sample_std = extract_scalar(&client.std(&dense, &reduce_axes, false, 1)?)?;
    let count = dense.numel() as f64;
    let standard_error = sample_std / count.sqrt();
    Ok(Tensor::<R>::full_scalar(&[], x.dtype(), standard_error, client.device()))
}