use crate::dtype::Element;
use crate::error::{Error, Result};
/// Guard against division by (near-)zero in the moment-based statistics of this module.
///
/// `compute_skewness` / `compute_kurtosis` return `0.0` when the spread falls below this
/// value, while the `cuda`/`wgpu` composites (`skew_composite`, `kurtosis_composite`) add it
/// to the denominator before dividing.
pub const DIVISION_EPSILON: f64 = 1e-10;
/// Strategy for combining two neighbouring values `lower_val` / `upper_val` given a
/// fractional position `frac` between them (e.g. when a quantile falls between two
/// order statistics).
///
/// Parse a user-supplied name with [`Interpolation::parse`]; the default is
/// [`Interpolation::Linear`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum Interpolation {
    /// Weighted blend `lower * (1 - frac) + upper * frac` (exact at `frac == 0` / `frac == 1`).
    #[default]
    Linear,
    /// Always the lower neighbour.
    Lower,
    /// Always the upper neighbour.
    Higher,
    /// The lower neighbour when `frac < 0.5`, otherwise the upper one (ties go up).
    Nearest,
    /// Arithmetic mean of the two neighbours; `frac` is ignored.
    Midpoint,
}

impl Interpolation {
    /// Parses an interpolation method name, case-insensitively.
    ///
    /// # Errors
    /// Returns `Error::InvalidArgument` (with `arg = "interpolation"`) for any name other
    /// than `linear`, `lower`, `higher`, `nearest` or `midpoint`.
    pub fn parse(s: &str) -> Result<Self> {
        match s.to_lowercase().as_str() {
            "linear" => Ok(Interpolation::Linear),
            "lower" => Ok(Interpolation::Lower),
            "higher" => Ok(Interpolation::Higher),
            "nearest" => Ok(Interpolation::Nearest),
            "midpoint" => Ok(Interpolation::Midpoint),
            _ => Err(Error::InvalidArgument {
                arg: "interpolation",
                reason: format!(
                    "Invalid interpolation method '{}'. Valid options: linear, lower, higher, nearest, midpoint",
                    s
                ),
            }),
        }
    }

    /// Combines `lower_val` and `upper_val` according to this strategy.
    ///
    /// `frac` is the fractional position between the two values; it is presumably in
    /// `[0, 1]`, as produced by `compute_quantile_indices` — not checked here.
    #[inline]
    pub fn interpolate(&self, lower_val: f64, upper_val: f64, frac: f64) -> f64 {
        match self {
            Interpolation::Linear => {
                // Return the endpoints exactly. The blended formula multiplies the unused
                // endpoint by 0.0, which yields NaN when that endpoint is infinite
                // (e.g. `(inf, inf, 0.0)` -> `inf + NaN`), and may also round a value that
                // should be reproduced verbatim when both neighbours are equal.
                if frac == 0.0 || lower_val == upper_val {
                    lower_val
                } else if frac == 1.0 {
                    upper_val
                } else {
                    lower_val * (1.0 - frac) + upper_val * frac
                }
            }
            Interpolation::Lower => lower_val,
            Interpolation::Higher => upper_val,
            Interpolation::Nearest => {
                if frac < 0.5 {
                    lower_val
                } else {
                    upper_val
                }
            }
            Interpolation::Midpoint => (lower_val + upper_val) / 2.0,
        }
    }
}
/// Computes `bins + 1` evenly spaced histogram edges spanning `[min_val, max_val]`.
///
/// The first edge is exactly `min_val` and the last edge is pinned to exactly `max_val`.
/// Accumulating `min_val + bins * width` can miss it by an ulp (e.g. 49 bins over
/// `[0, 1]` would end at `0.9999999999999999`).
/// With `bins == 0` there are no bins, and the single edge `min_val` is returned.
#[inline]
pub fn compute_bin_edges_f64(min_val: f64, max_val: f64, bins: usize) -> Vec<f64> {
    if bins == 0 {
        // Avoid dividing by zero, which previously produced `[NaN]`.
        return vec![min_val];
    }
    let bin_width = (max_val - min_val) / bins as f64;
    (0..=bins)
        .map(|i| {
            if i == bins {
                max_val
            } else {
                min_val + i as f64 * bin_width
            }
        })
        .collect()
}
/// Maps `value` to a histogram bin index in `0..bins`, given the lower bound `min_val`
/// and a uniform `bin_width`.
///
/// Positions below the range clamp to bin 0. Positions at or past `bins` clamp to the
/// last bin, so a value exactly on the upper edge is counted in the final bin. A NaN
/// position casts to 0. When the computed position is non-negative, `bins` must be
/// non-zero, otherwise `bins - 1` underflows.
#[inline]
pub fn compute_bin_index(value: f64, min_val: f64, bin_width: f64, bins: usize) -> usize {
    let position = ((value - min_val) / bin_width).floor() as isize;
    let bin_count = bins as isize;
    match position {
        p if p < 0 => 0,
        p if p >= bin_count => bins - 1,
        p => p as usize,
    }
}
/// Maps quantile `q` onto positions `0..n` of a sorted array via the virtual index
/// `q * (n - 1)`.
///
/// Returns `(floor_index, ceil_index, frac)`, where `frac` is the fractional part of the
/// virtual index and `ceil_index` is clamped to `n - 1`. `n > 0` and `q` in `[0, 1]` are
/// only checked in debug builds.
#[inline]
pub fn compute_quantile_indices(q: f64, n: usize) -> (usize, usize, f64) {
    debug_assert!(n > 0, "Array size must be positive");
    debug_assert!((0.0..=1.0).contains(&q), "Quantile must be in [0, 1]");
    let last = n - 1;
    let position = q * last as f64;
    let below = position.floor() as usize;
    let above = usize::min(position.ceil() as usize, last);
    (below, above, position - below as f64)
}
/// Skewness of `a` over `dims`, assembled from generic tensor ops (compiled only with the
/// `cuda` or `wgpu` feature).
///
/// Computes `mean((a - mean(a))^3) / (std(a, correction)^3 + DIVISION_EPSILON)`. The
/// epsilon is added instead of branching, so a constant input yields roughly `0 / eps`.
///
/// NOTE(review): the third moment uses a plain mean, while the denominator uses `std` with
/// `correction`. `compute_skewness` instead ignores `correction` and returns `0.0` on
/// degenerate spread, so results differ when `correction != 0`. Confirm this is intended.
///
/// # Errors
/// Propagates any error returned by the client's `mean`, `sub`, `mul`, `std`, `add` or `div`.
#[cfg(any(feature = "cuda", feature = "wgpu"))]
pub fn skew_composite<R, C>(
    client: &C,
    a: &crate::tensor::Tensor<R>,
    dims: &[usize],
    keepdim: bool,
    correction: usize,
) -> Result<crate::tensor::Tensor<R>>
where
    R: crate::runtime::Runtime<DType = crate::dtype::DType>,
    C: crate::ops::BinaryOps<R>
        + crate::ops::ReduceOps<R>
        + crate::ops::StatisticalOps<R>
        + crate::runtime::RuntimeClient<R>,
{
    let dtype = a.dtype();
    // keepdim = true so the mean keeps `a`'s rank for the elementwise `sub` below.
    let mean = client.mean(a, dims, true)?;
    let centered = client.sub(a, &mean)?;
    let centered_sq = client.mul(&centered, &centered)?;
    let centered_cubed = client.mul(&centered_sq, &centered)?;
    // Third central moment, reduced with the caller's `keepdim`.
    let m3 = client.mean(&centered_cubed, dims, keepdim)?;
    let std_val = client.std(a, dims, keepdim, correction)?;
    let std_sq = client.mul(&std_val, &std_val)?;
    let std_cubed = client.mul(&std_sq, &std_val)?;
    let epsilon = crate::tensor::Tensor::<R>::full_scalar(
        std_cubed.shape(),
        dtype,
        DIVISION_EPSILON,
        client.device(),
    );
    let std_cubed_safe = client.add(&std_cubed, &epsilon)?;
    client.div(&m3, &std_cubed_safe)
}
/// Excess kurtosis of `a` over `dims`, assembled from generic tensor ops (compiled only
/// with the `cuda` or `wgpu` feature).
///
/// Computes `mean((a - mean(a))^4) / (std(a, correction)^4 + DIVISION_EPSILON) - 3`. The
/// epsilon is added instead of branching, so a constant input yields roughly `0 / eps - 3`.
///
/// NOTE(review): the fourth moment uses a plain mean, while the denominator uses `std` with
/// `correction`. `compute_kurtosis` instead ignores `correction` and returns `0.0` on
/// degenerate spread, so results can differ from the CPU helper. Confirm this is intended.
///
/// # Errors
/// Propagates any error returned by the client's `mean`, `sub`, `mul`, `std`, `add` or `div`.
#[cfg(any(feature = "cuda", feature = "wgpu"))]
pub fn kurtosis_composite<R, C>(
    client: &C,
    a: &crate::tensor::Tensor<R>,
    dims: &[usize],
    keepdim: bool,
    correction: usize,
) -> Result<crate::tensor::Tensor<R>>
where
    R: crate::runtime::Runtime<DType = crate::dtype::DType>,
    C: crate::ops::BinaryOps<R>
        + crate::ops::ReduceOps<R>
        + crate::ops::StatisticalOps<R>
        + crate::runtime::RuntimeClient<R>,
{
    let dtype = a.dtype();
    // keepdim = true so the mean keeps `a`'s rank for the elementwise `sub` below.
    let mean = client.mean(a, dims, true)?;
    let centered = client.sub(a, &mean)?;
    let centered_sq = client.mul(&centered, &centered)?;
    let centered_fourth = client.mul(&centered_sq, &centered_sq)?;
    // Fourth central moment, reduced with the caller's `keepdim`.
    let m4 = client.mean(&centered_fourth, dims, keepdim)?;
    let std_val = client.std(a, dims, keepdim, correction)?;
    let std_sq = client.mul(&std_val, &std_val)?;
    let std_fourth = client.mul(&std_sq, &std_sq)?;
    let epsilon = crate::tensor::Tensor::<R>::full_scalar(
        std_fourth.shape(),
        dtype,
        DIVISION_EPSILON,
        client.device(),
    );
    let std_fourth_safe = client.add(&std_fourth, &epsilon)?;
    let ratio = client.div(&m4, &std_fourth_safe)?;
    // Subtract 3 so a normal distribution maps to 0 (excess kurtosis).
    let three = crate::tensor::Tensor::<R>::full_scalar(ratio.shape(), dtype, 3.0, client.device());
    client.sub(&ratio, &three)
}
/// Skewness `m3 / m2^(3/2)` of `data`, where `m2` and `m3` are the second and third central
/// moments, each normalised by `n`.
///
/// Returns `0.0` for fewer than three elements, or when the standard deviation `sqrt(m2)`
/// is below [`DIVISION_EPSILON`]. The `_correction` argument is currently ignored.
pub fn compute_skewness<T: Element>(data: &[T], _correction: usize) -> f64 {
    let len = data.len();
    if len < 3 {
        return 0.0;
    }
    let n = len as f64;
    let mean = data.iter().map(|v| v.to_f64()).sum::<f64>() / n;
    // Single pass accumulating squared and cubed deviations from the mean.
    let (sum_sq, sum_cube) = data.iter().fold((0.0f64, 0.0f64), |(sq, cube), v| {
        let dev = v.to_f64() - mean;
        let dev_sq = dev * dev;
        (sq + dev_sq, cube + dev_sq * dev)
    });
    let sigma = (sum_sq / n).sqrt();
    if sigma < DIVISION_EPSILON {
        return 0.0;
    }
    (sum_cube / n) / (sigma * sigma * sigma)
}
/// Excess kurtosis `m4 / m2^2 - 3` of `data`, where `m2` and `m4` are the second and fourth
/// central moments, each normalised by `n`.
///
/// Returns `0.0` for fewer than four elements, or when the standard deviation `sqrt(m2)` is
/// below [`DIVISION_EPSILON`]. This is the same cutoff `compute_skewness` applies. The
/// `_correction` argument is currently ignored.
pub fn compute_kurtosis<T: Element>(data: &[T], _correction: usize) -> f64 {
    let n = data.len();
    if n < 4 {
        return 0.0;
    }
    let sum: f64 = data.iter().map(|v| v.to_f64()).sum();
    let mean = sum / n as f64;
    let mut m2 = 0.0f64;
    let mut m4 = 0.0f64;
    for &val in data {
        let diff = val.to_f64() - mean;
        let diff2 = diff * diff;
        m2 += diff2;
        m4 += diff2 * diff2;
    }
    m2 /= n as f64;
    m4 /= n as f64;
    // Kurtosis is scale-invariant, so the degenerate-spread cutoff is applied to the standard
    // deviation. Comparing the *variance* to the epsilon would treat any data with std below
    // ~1e-5 as constant and silently return 0.0.
    if m2.sqrt() < DIVISION_EPSILON {
        0.0
    } else {
        m4 / (m2 * m2) - 3.0
    }
}
/// Result of a mode computation: the most frequent value and its number of occurrences.
#[derive(Debug, Clone)]
pub struct ModeResult<T> {
    /// The most frequent value.
    pub value: T,
    /// How many times `value` occurs in the reduced slice.
    pub count: i64,
}
/// Finds the most frequent value in a sorted slice.
///
/// Equal values are detected by comparing their `to_f64()` representations, so they must
/// be adjacent (hence the sorted input). On a tie the earliest run wins, which for ascending
/// input is the smallest value. NaN never equals itself, so each NaN forms its own run of
/// length 1.
///
/// NOTE(review): integer element types with magnitudes beyond f64's 53-bit mantissa may
/// compare equal after conversion. Confirm whether such dtypes reach this function.
///
/// # Panics
/// Panics on an empty slice: via `debug_assert!` in debug builds, and via an out-of-bounds
/// index otherwise.
pub fn compute_mode<T: Element>(sorted: &[T]) -> ModeResult<T> {
    debug_assert!(!sorted.is_empty(), "Cannot compute mode of empty slice");
    let mut best = ModeResult {
        value: sorted[0],
        count: 1,
    };
    let mut start = 0;
    while start < sorted.len() {
        let head = sorted[start];
        let key = head.to_f64();
        // Length of the run of values equal to `head`, including `head` itself.
        let run = 1 + sorted[start + 1..]
            .iter()
            .take_while(|v| v.to_f64() == key)
            .count();
        let run_count = run as i64;
        // Strictly greater, so an earlier run keeps the title on ties.
        if run_count > best.count {
            best = ModeResult {
                value: head,
                count: run_count,
            };
        }
        start += run;
    }
    best
}
/// Computes the mode along the middle axis of a row-major `[outer, reduce, inner]` buffer.
///
/// For every `(outer, inner)` pair, this gathers the `reduce_size` elements at
/// `outer * reduce_size * inner_size + r * inner_size + inner`, sorts them by their `f64`
/// value, and passes them to [`compute_mode`]. Results are returned as `(values, counts)`
/// in outer-major, inner-minor order.
///
/// Despite its name, `sorted` does not need to be pre-sorted: each gathered lane is sorted
/// here.
///
/// # Panics
/// Panics if `sorted` is shorter than `outer_size * reduce_size * inner_size`. It also
/// panics, via [`compute_mode`], if `reduce_size == 0` while the output is non-empty.
pub fn compute_mode_strided<T: Element>(
    sorted: &[T],
    outer_size: usize,
    reduce_size: usize,
    inner_size: usize,
) -> (Vec<T>, Vec<i64>) {
    let output_size = outer_size * inner_size;
    let mut values = Vec::with_capacity(output_size);
    let mut counts = Vec::with_capacity(output_size);
    // One scratch buffer reused for every lane instead of a fresh allocation per output.
    let mut lane: Vec<T> = Vec::with_capacity(reduce_size);
    for outer in 0..outer_size {
        let base = outer * reduce_size * inner_size;
        for inner in 0..inner_size {
            lane.clear();
            lane.extend((0..reduce_size).map(|r| sorted[base + r * inner_size + inner]));
            // `total_cmp` is a genuine total order: positive NaNs sort last, negative NaNs
            // first, and -0.0 sorts just before 0.0, so equal values stay adjacent. The former
            // `partial_cmp(..).unwrap_or(Equal)` made NaN "equal" to everything, which is not
            // transitive and breaks `sort_by`'s contract (newer std versions may panic).
            lane.sort_by(|a, b| a.to_f64().total_cmp(&b.to_f64()));
            let result = compute_mode(&lane);
            values.push(result.value);
            counts.push(result.count);
        }
    }
    (values, counts)
}
#[cfg(test)]
mod tests {
    //! Unit tests for the statistics helpers in the parent module.
    use super::*;

    /// Absolute tolerance shared by the floating-point checks below.
    const TOL: f64 = 1e-10;

    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < TOL);
    }

    #[test]
    fn test_interpolation_from_str() {
        let accepted = [
            ("linear", Interpolation::Linear),
            ("LINEAR", Interpolation::Linear),
            ("lower", Interpolation::Lower),
            ("higher", Interpolation::Higher),
            ("nearest", Interpolation::Nearest),
            ("midpoint", Interpolation::Midpoint),
        ];
        for &(name, expected) in &accepted {
            assert_eq!(Interpolation::parse(name).unwrap(), expected);
        }
        assert!(Interpolation::parse("invalid").is_err());
    }

    #[test]
    fn test_interpolation_values() {
        let (lower, upper, frac) = (1.0, 2.0, 0.75);
        let expectations = [
            (Interpolation::Linear, 1.75),
            (Interpolation::Lower, 1.0),
            (Interpolation::Higher, 2.0),
            (Interpolation::Nearest, 2.0),
            (Interpolation::Midpoint, 1.5),
        ];
        for &(method, expected) in &expectations {
            assert_close(method.interpolate(lower, upper, frac), expected);
        }
    }

    #[test]
    fn test_compute_bin_edges() {
        let edges = compute_bin_edges_f64(0.0, 10.0, 5);
        assert_eq!(edges.len(), 6);
        assert_close(edges[0], 0.0);
        assert_close(edges[2], 4.0);
        assert_close(edges[5], 10.0);
    }

    #[test]
    fn test_compute_bin_index() {
        let (min, width, bins) = (0.0, 2.0, 5);
        assert_eq!(compute_bin_index(0.5, min, width, bins), 0);
        assert_eq!(compute_bin_index(2.5, min, width, bins), 1);
        // Out-of-range values clamp to the first / last bin.
        assert_eq!(compute_bin_index(-1.0, min, width, bins), 0);
        assert_eq!(compute_bin_index(100.0, min, width, bins), 4);
    }

    #[test]
    fn test_compute_quantile_indices() {
        let (f, c, frac) = compute_quantile_indices(0.5, 5);
        assert_eq!((f, c), (2, 2));
        assert!(frac.abs() < TOL);

        let (f, c, frac) = compute_quantile_indices(0.25, 5);
        assert_eq!((f, c), (1, 1));
        assert!(frac.abs() < TOL);
    }

    #[test]
    fn test_skewness_symmetric() {
        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let skew = compute_skewness(&data, 0);
        assert!(
            skew.abs() < 0.1,
            "Symmetric data should have near-zero skewness"
        );
    }

    #[test]
    fn test_kurtosis_uniform() {
        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let kurt = compute_kurtosis(&data, 0);
        assert!(
            kurt < 0.0,
            "Uniform-like data should have negative kurtosis"
        );
    }

    /// Runs `compute_mode` on `data` and checks both the modal value and its count.
    fn check_mode(data: &[f32], value: f32, count: i64) {
        let result = compute_mode(data);
        assert_close(f64::from(result.value), f64::from(value));
        assert_eq!(result.count, count);
    }

    #[test]
    fn test_compute_mode_simple() {
        check_mode(&[1.0, 2.0, 2.0, 2.0, 3.0], 2.0, 3);
    }

    #[test]
    fn test_compute_mode_all_unique() {
        check_mode(&[1.0, 2.0, 3.0, 4.0, 5.0], 1.0, 1);
    }

    #[test]
    fn test_compute_mode_tie() {
        // Ties resolve to the smallest value in sorted input.
        check_mode(&[1.0, 1.0, 2.0, 3.0, 3.0], 1.0, 2);
    }

    #[test]
    fn test_compute_mode_single_element() {
        check_mode(&[42.0], 42.0, 1);
    }

    #[test]
    fn test_compute_mode_all_same() {
        check_mode(&[7.0, 7.0, 7.0, 7.0], 7.0, 4);
    }
}