use super::super::helpers::{dispatch_dtype, ensure_contiguous};
use super::super::{CpuClient, CpuRuntime};
use crate::dtype::Element;
use crate::error::Result;
use crate::ops::{BinaryOps, ReduceOps, ScalarOps, StatisticalOps};
use crate::runtime::common::statistics_common::{
DIVISION_EPSILON, compute_kurtosis, compute_skewness,
};
use crate::tensor::Tensor;
/// Computes the skewness of `a`, either over the whole tensor or along `dims`.
///
/// Two code paths exist:
///
/// * **`dims` is empty** – the tensor is made contiguous, viewed as a flat slice of
///   `numel` elements, and handed to `compute_skewness` together with `correction`.
///   The scalar result is stored in a freshly allocated tensor whose shape is
///   `[1; ndim]` when `keepdim` is true and `[]` otherwise.
/// * **`dims` is non-empty** – the result is assembled from tensor ops as
///   `mean((a - mean(a))^3) / (std(a, correction)^3 + DIVISION_EPSILON)`.
///   Note that only the standard deviation receives `correction`; the third central
///   moment is a plain `mean` over `dims`.
///
/// # Arguments
///
/// * `client` – CPU client providing the tensor ops and the target device.
/// * `a` – input tensor.
/// * `dims` – dimensions to reduce over; empty means "reduce over everything".
/// * `keepdim` – whether reduced dimensions are kept in the output shape.
/// * `correction` – correction term forwarded to `compute_skewness` / `client.std`.
///
/// # Errors
///
/// Propagates any error returned by the underlying `client` operations
/// (`mean`, `sub`, `pow_scalar`, `std`, `add`, `div`).
pub fn skew_impl(
    client: &CpuClient,
    a: &Tensor<CpuRuntime>,
    dims: &[usize],
    keepdim: bool,
    correction: usize,
) -> Result<Tensor<CpuRuntime>> {
    let dtype = a.dtype();
    let shape = a.shape();
    let ndim = shape.len();

    // Full reduction: compute a single scalar directly on the flattened data.
    if dims.is_empty() {
        let numel = a.numel();
        // `a_contig` must stay alive for as long as `a_ptr` is dereferenced below.
        let a_contig = ensure_contiguous(a);
        let a_ptr = a_contig.ptr();

        let skewness = dispatch_dtype!(dtype, T => {
            // SAFETY: relies on `a_contig` holding `numel` densely packed elements of the
            // dispatched type `T` at `a_ptr` (assumed contract of `ensure_contiguous` and
            // `dispatch_dtype!` — confirm against their definitions). `a_contig` outlives
            // the slice.
            // NOTE(review): if `numel == 0` and `ptr()` can be null, `from_raw_parts`
            // would be unsound — verify how empty tensors are represented.
            unsafe {
                let slice = std::slice::from_raw_parts(a_ptr as *const T, numel);
                compute_skewness(slice, correction)
            }
        }, "skew");

        let out_shape = if keepdim { vec![1; ndim] } else { vec![] };
        let out = Tensor::<CpuRuntime>::empty(&out_shape, dtype, &client.device);
        let out_ptr = out.ptr();

        dispatch_dtype!(dtype, T => {
            // SAFETY: `out` was just allocated with `dtype` and a shape describing exactly
            // one element, so writing a single `T` at `out_ptr` is assumed to be in bounds
            // (depends on `Tensor::empty` allocating storage for that element — confirm).
            unsafe { *(out_ptr as *mut T) = T::from_f64(skewness); }
        }, "skew");

        return Ok(out);
    }

    // Reduction along `dims`. The mean is taken with keepdim = true so that it can be
    // subtracted from `a` without reshaping.
    let mean = client.mean(a, dims, true)?;
    let centered = client.sub(a, &mean)?;

    // Third central moment: mean of the cubed deviations.
    let centered_cubed = client.pow_scalar(&centered, 3.0)?;
    let m3 = client.mean(&centered_cubed, dims, keepdim)?;

    // Denominator: std^3, reduced with the same `keepdim` so shapes line up with `m3`.
    let std_val = client.std(a, dims, keepdim, correction)?;
    let std_cubed = client.pow_scalar(&std_val, 3.0)?;

    // DIVISION_EPSILON is added to the denominator before dividing (presumably to avoid
    // dividing by zero for constant inputs).
    let epsilon = Tensor::<CpuRuntime>::full_scalar(
        std_cubed.shape(),
        dtype,
        DIVISION_EPSILON,
        &client.device,
    );
    let std_cubed_safe = client.add(&std_cubed, &epsilon)?;

    client.div(&m3, &std_cubed_safe)
}
/// Computes the kurtosis of `a`, either over the whole tensor or along `dims`.
///
/// Two code paths exist:
///
/// * **`dims` is empty** – the tensor is made contiguous, viewed as a flat slice of
///   `numel` elements, and handed to `compute_kurtosis` together with `correction`.
///   The scalar result is stored in a freshly allocated tensor whose shape is
///   `[1; ndim]` when `keepdim` is true and `[]` otherwise.
/// * **`dims` is non-empty** – the result is assembled from tensor ops as
///   `mean((a - mean(a))^4) / (std(a, correction)^4 + DIVISION_EPSILON) - 3`,
///   i.e. the excess kurtosis. Only the standard deviation receives `correction`;
///   the fourth central moment is a plain `mean` over `dims`.
///
/// NOTE(review): whether `compute_kurtosis` also subtracts 3 is not visible here —
/// confirm that both paths agree on the convention.
///
/// # Arguments
///
/// * `client` – CPU client providing the tensor ops and the target device.
/// * `a` – input tensor.
/// * `dims` – dimensions to reduce over; empty means "reduce over everything".
/// * `keepdim` – whether reduced dimensions are kept in the output shape.
/// * `correction` – correction term forwarded to `compute_kurtosis` / `client.std`.
///
/// # Errors
///
/// Propagates any error returned by the underlying `client` operations
/// (`mean`, `sub`, `pow_scalar`, `std`, `add`, `div`).
pub fn kurtosis_impl(
    client: &CpuClient,
    a: &Tensor<CpuRuntime>,
    dims: &[usize],
    keepdim: bool,
    correction: usize,
) -> Result<Tensor<CpuRuntime>> {
    let dtype = a.dtype();
    let shape = a.shape();
    let ndim = shape.len();

    // Full reduction: compute a single scalar directly on the flattened data.
    if dims.is_empty() {
        let numel = a.numel();
        // `a_contig` must stay alive for as long as `a_ptr` is dereferenced below.
        let a_contig = ensure_contiguous(a);
        let a_ptr = a_contig.ptr();

        let kurtosis = dispatch_dtype!(dtype, T => {
            // SAFETY: relies on `a_contig` holding `numel` densely packed elements of the
            // dispatched type `T` at `a_ptr` (assumed contract of `ensure_contiguous` and
            // `dispatch_dtype!` — confirm against their definitions). `a_contig` outlives
            // the slice.
            // NOTE(review): if `numel == 0` and `ptr()` can be null, `from_raw_parts`
            // would be unsound — verify how empty tensors are represented.
            unsafe {
                let slice = std::slice::from_raw_parts(a_ptr as *const T, numel);
                compute_kurtosis(slice, correction)
            }
        }, "kurtosis");

        let out_shape = if keepdim { vec![1; ndim] } else { vec![] };
        let out = Tensor::<CpuRuntime>::empty(&out_shape, dtype, &client.device);
        let out_ptr = out.ptr();

        dispatch_dtype!(dtype, T => {
            // SAFETY: `out` was just allocated with `dtype` and a shape describing exactly
            // one element, so writing a single `T` at `out_ptr` is assumed to be in bounds
            // (depends on `Tensor::empty` allocating storage for that element — confirm).
            unsafe { *(out_ptr as *mut T) = T::from_f64(kurtosis); }
        }, "kurtosis");

        return Ok(out);
    }

    // Reduction along `dims`. The mean is taken with keepdim = true so that it can be
    // subtracted from `a` without reshaping.
    let mean = client.mean(a, dims, true)?;
    let centered = client.sub(a, &mean)?;

    // Fourth central moment: mean of the deviations raised to the fourth power.
    let centered_fourth = client.pow_scalar(&centered, 4.0)?;
    let m4 = client.mean(&centered_fourth, dims, keepdim)?;

    // Denominator: std^4, reduced with the same `keepdim` so shapes line up with `m4`.
    let std_val = client.std(a, dims, keepdim, correction)?;
    let std_fourth = client.pow_scalar(&std_val, 4.0)?;

    // DIVISION_EPSILON is added to the denominator before dividing (presumably to avoid
    // dividing by zero for constant inputs).
    let epsilon = Tensor::<CpuRuntime>::full_scalar(
        std_fourth.shape(),
        dtype,
        DIVISION_EPSILON,
        &client.device,
    );
    let std_fourth_safe = client.add(&std_fourth, &epsilon)?;
    let ratio = client.div(&m4, &std_fourth_safe)?;

    // Subtract 3 from the standardized fourth moment to obtain excess kurtosis.
    let three = Tensor::<CpuRuntime>::full_scalar(ratio.shape(), dtype, 3.0, &client.device);
    client.sub(&ratio, &three)
}