use crate::DType;
use crate::stats::helpers::{extract_scalar, tensor_median_scalar};
use crate::stats::traits::{RobustRegressionResult, validate_stats_dtype};
use numr::error::{Error, Result};
use numr::ops::{CompareOps, TensorOps};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
/// Trimmed mean: the mean of `x` after discarding `ncut` of the smallest and
/// `ncut` of the largest elements, where `ncut = floor(x.numel() * proportiontocut)`.
///
/// `x` is flattened before sorting, so inputs of any shape are accepted; the
/// mean is reduced over all remaining elements (`keepdim = false`).
///
/// # Errors
/// - `x.dtype()` is rejected by `validate_stats_dtype`;
/// - `proportiontocut` is not in `[0, 0.5)` (NaN is rejected too, since
///   `Range::contains` is false for NaN);
/// - `x` has fewer than 2 elements.
pub fn trim_mean_impl<R, C>(client: &C, x: &Tensor<R>, proportiontocut: f64) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + RuntimeClient<R>,
{
    validate_stats_dtype(x.dtype())?;
    if !(0.0..0.5).contains(&proportiontocut) {
        return Err(Error::InvalidArgument {
            arg: "proportiontocut",
            reason: "must be in [0, 0.5)".to_string(),
        });
    }
    let x_contig = x.contiguous()?;
    let n = x_contig.numel();
    if n < 2 {
        return Err(Error::InvalidArgument {
            arg: "x",
            reason: "trimmed mean requires at least 2 samples".to_string(),
        });
    }
    let ncut = (n as f64 * proportiontocut).floor() as usize;
    // Defensive only: for proportiontocut < 0.5, floor(n * p) < n / 2 always holds.
    if 2 * ncut >= n {
        return Err(Error::InvalidArgument {
            arg: "proportiontocut",
            reason: "proportion too large for sample size".to_string(),
        });
    }
    // Flatten so the sort and the narrow below address every element; `ncut`
    // is derived from `numel()`, not from the length of dimension 0.
    let flat = x_contig.reshape(&[n])?;
    let sorted = client.sort(&flat, 0, false)?;
    let trimmed = sorted.narrow(0, ncut, n - 2 * ncut)?.contiguous()?;
    client.mean(&trimmed, &[0], false)
}
/// Winsorized mean of `x`.
///
/// With `ncut = floor(x.numel() * proportiontocut)`, every element below the
/// `ncut`-th order statistic (0-based, ascending) is raised to it. Every element
/// above the `(n - ncut - 1)`-th order statistic is lowered to it. The mean of
/// the clamped values is then returned.
///
/// `x` is flattened first, so inputs of any shape are accepted; the mean is
/// reduced over all elements (`keepdim = false`).
///
/// # Errors
/// - `x.dtype()` is rejected by `validate_stats_dtype`;
/// - `proportiontocut` is not in `[0, 0.5)` (NaN is rejected too);
/// - `x` has fewer than 2 elements.
pub fn winsorized_mean_impl<R, C>(
    client: &C,
    x: &Tensor<R>,
    proportiontocut: f64,
) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + RuntimeClient<R>,
{
    validate_stats_dtype(x.dtype())?;
    if !(0.0..0.5).contains(&proportiontocut) {
        return Err(Error::InvalidArgument {
            arg: "proportiontocut",
            reason: "must be in [0, 0.5)".to_string(),
        });
    }
    let x_contig = x.contiguous()?;
    let n = x_contig.numel();
    if n < 2 {
        return Err(Error::InvalidArgument {
            arg: "x",
            reason: "winsorized mean requires at least 2 samples".to_string(),
        });
    }
    let ncut = (n as f64 * proportiontocut).floor() as usize;
    // Defensive only: for proportiontocut < 0.5, floor(n * p) < n / 2 always holds.
    if 2 * ncut >= n {
        return Err(Error::InvalidArgument {
            arg: "proportiontocut",
            reason: "proportion too large for sample size".to_string(),
        });
    }
    // Flatten so that `narrow(0, k, 1)` selects a single order statistic rather
    // than a whole slice along dimension 0 of a multi-dimensional input.
    let flat = x_contig.reshape(&[n])?;
    let sorted = client.sort(&flat, 0, false)?;
    let low_val = extract_scalar(&sorted.narrow(0, ncut, 1)?)?;
    let high_val = extract_scalar(&sorted.narrow(0, n - ncut - 1, 1)?)?;
    let clamped = client.clamp(&flat, low_val, high_val)?;
    client.mean(&clamped, &[0], false)
}
/// Median absolute deviation: `median(|x - median(x)|)`.
///
/// When `scale` is true the result is multiplied by 1.4826 (≈ 1 / Φ⁻¹(0.75),
/// the conventional normal-consistency factor). The value is returned as a
/// tensor created with shape `[]` and the dtype of `x`.
///
/// # Errors
/// - `x.dtype()` is rejected by `validate_stats_dtype`;
/// - `x` is empty.
pub fn median_abs_deviation_impl<R, C>(client: &C, x: &Tensor<R>, scale: bool) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + RuntimeClient<R>,
{
    const NORMAL_SCALE: f64 = 1.4826;

    validate_stats_dtype(x.dtype())?;
    let dtype = x.dtype();
    let device = client.device();

    let data = x.contiguous()?;
    if data.numel() == 0 {
        return Err(Error::InvalidArgument {
            arg: "x",
            reason: "MAD requires at least 1 sample".to_string(),
        });
    }

    // Center on the sample median, then take the median of absolute distances to it.
    let center = tensor_median_scalar(client, &data)?;
    let center_t = Tensor::<R>::full_scalar(data.shape(), dtype, center, device);
    let abs_dev = client.abs(&client.sub(&data, &center_t)?)?;
    let raw_mad = tensor_median_scalar(client, &abs_dev)?;

    let mad = match scale {
        true => raw_mad * NORMAL_SCALE,
        false => raw_mad,
    };
    Ok(Tensor::<R>::full_scalar(&[], dtype, mad, device))
}
/// Theil–Sen robust linear regression of `y` on `x`.
///
/// The slope is the median of all pairwise slopes `(y[j] - y[i]) / (x[j] - x[i])`
/// over pairs `i < j` with `x[i] != x[j]`. The intercept is `median(y - slope * x)`.
/// `low_slope` / `high_slope` bound an approximate 95% confidence interval on the
/// slope, using Sen's (1968) rank-based construction as in `scipy.stats.theilslopes`.
///
/// Both inputs are flattened, so any shapes are accepted as long as `x` and `y`
/// contain the same number of elements.
///
/// # Errors
/// - either dtype is rejected by `validate_stats_dtype`;
/// - `x` and `y` differ in element count, or contain fewer than 2 points;
/// - all `x` values are identical (no pair yields a defined slope).
pub fn theilslopes_impl<R, C>(
    client: &C,
    x: &Tensor<R>,
    y: &Tensor<R>,
) -> Result<RobustRegressionResult<R>>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + CompareOps<R> + RuntimeClient<R>,
{
    // Two-sided 95% standard normal quantile, Phi^-1(0.975).
    const Z_975: f64 = 1.959_963_984_540_054;

    validate_stats_dtype(x.dtype())?;
    validate_stats_dtype(y.dtype())?;
    let n = x.numel();
    if n != y.numel() {
        return Err(Error::InvalidArgument {
            arg: "x/y",
            reason: "must have equal length".to_string(),
        });
    }
    if n < 2 {
        return Err(Error::InvalidArgument {
            arg: "x/y",
            reason: "Theil-Sen requires at least 2 points".to_string(),
        });
    }
    let device = client.device();
    let dtype = x.dtype();
    let x_contig = x.contiguous()?;
    let y_contig = y.contiguous()?;

    // Pairwise differences by broadcasting a column against a row: d[i, j] = v[j] - v[i].
    let x_col = x_contig.reshape(&[n, 1])?;
    let x_row = x_contig.reshape(&[1, n])?;
    let y_col = y_contig.reshape(&[n, 1])?;
    let y_row = y_contig.reshape(&[1, n])?;
    let dx = client.sub(&x_row, &x_col)?;
    let dy = client.sub(&y_row, &y_col)?;
    let slopes_matrix = client.div(&dy, &dx)?;

    // Keep each unordered pair once (strict upper triangle, i < j)...
    let row_idx = client.arange(0.0, n as f64, 1.0, dtype)?.reshape(&[n, 1])?;
    let col_idx = client.arange(0.0, n as f64, 1.0, dtype)?.reshape(&[1, n])?;
    let upper_triangle = client.lt(&row_idx, &col_idx)?;
    // ...and only pairs whose x values actually differ. An exact test is used:
    // an absolute epsilon would wrongly discard distinct pairs on small-scale data.
    let zero = Tensor::<R>::full_scalar(&[1, 1], dtype, 0.0, device);
    let dx_abs = client.abs(&dx)?;
    let nonzero_mask = client.gt(&dx_abs, &zero)?;
    let valid_mask = client.mul(&upper_triangle, &nonzero_mask)?;
    let valid_mask_u8 = client.cast(&valid_mask, numr::dtype::DType::U8)?;
    let valid_slopes = client.masked_select(&slopes_matrix, &valid_mask_u8)?;
    if valid_slopes.numel() == 0 {
        return Err(Error::InvalidArgument {
            arg: "x",
            reason: "all x values are identical".to_string(),
        });
    }

    let slope = tensor_median_scalar(client, &valid_slopes)?;
    let slope_t = Tensor::<R>::full_scalar(x_contig.shape(), dtype, slope, device);
    let slope_x = client.mul(&slope_t, &x_contig)?;
    let intercepts = client.sub(&y_contig, &slope_x)?;
    let intercept = tensor_median_scalar(client, &intercepts)?;

    // Confidence interval (Sen 1968, eq. 2.6). sigma^2 is the null variance of
    // Kendall's S. The bounds are the order statistics of the m sorted pairwise
    // slopes at 0-based ranks round((m - z*sigma)/2) - 1 and round((m + z*sigma)/2),
    // clamped to [0, m - 1].
    // NOTE(review): the tie corrections for repeated x / y values are not applied,
    // so the interval is somewhat wider (conservative) when ties are present.
    let m = valid_slopes.numel();
    let sorted_slopes = client.sort(&valid_slopes, 0, false)?;
    let nf = n as f64;
    let sigma = (nf * (nf - 1.0) * (2.0 * nf + 5.0) / 18.0).sqrt();
    let half_span = Z_975 * sigma;
    let mf = m as f64;
    let low_idx = ((((mf - half_span) / 2.0).round() - 1.0).max(0.0) as usize).min(m - 1);
    let high_idx = (((mf + half_span) / 2.0).round() as usize).min(m - 1);
    let low_slope = extract_scalar(&sorted_slopes.narrow(0, low_idx, 1)?)?;
    let high_slope = extract_scalar(&sorted_slopes.narrow(0, high_idx, 1)?)?;

    Ok(RobustRegressionResult {
        slope: Tensor::<R>::full_scalar(&[], dtype, slope, device),
        intercept: Tensor::<R>::full_scalar(&[], dtype, intercept, device),
        low_slope: Tensor::<R>::full_scalar(&[], dtype, low_slope, device),
        high_slope: Tensor::<R>::full_scalar(&[], dtype, high_slope, device),
    })
}
/// Siegel repeated-medians robust linear regression of `y` on `x`.
///
/// For each point `i`, take the median of its slopes to every other point with a
/// different `x` value. When that count is even, the two middle values are
/// averaged. The slope is the median of these per-point medians; the intercept is
/// `median(y - slope * x)`.
///
/// `low_slope` / `high_slope` are the smallest and largest per-point median
/// slopes. NOTE(review): this is a spread of the per-point estimates, not a
/// calibrated confidence interval (`scipy.stats.siegelslopes` reports none).
///
/// Both inputs are flattened; they must contain the same number of elements.
///
/// # Errors
/// - either dtype is rejected by `validate_stats_dtype`;
/// - `x` and `y` differ in element count, or contain fewer than 2 points;
/// - all `x` values are identical.
pub fn siegelslopes_impl<R, C>(
    client: &C,
    x: &Tensor<R>,
    y: &Tensor<R>,
) -> Result<RobustRegressionResult<R>>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + CompareOps<R> + RuntimeClient<R>,
{
    validate_stats_dtype(x.dtype())?;
    validate_stats_dtype(y.dtype())?;
    let n = x.numel();
    if n != y.numel() {
        return Err(Error::InvalidArgument {
            arg: "x/y",
            reason: "must have equal length".to_string(),
        });
    }
    if n < 2 {
        return Err(Error::InvalidArgument {
            arg: "x/y",
            reason: "Siegel slopes requires at least 2 points".to_string(),
        });
    }
    let device = client.device();
    let dtype = x.dtype();
    let x_contig = x.contiguous()?;
    let y_contig = y.contiguous()?;

    // Pairwise differences by broadcasting a column against a row: d[i, j] = v[j] - v[i].
    let x_col = x_contig.reshape(&[n, 1])?;
    let x_row = x_contig.reshape(&[1, n])?;
    let y_col = y_contig.reshape(&[n, 1])?;
    let y_row = y_contig.reshape(&[1, n])?;
    let dx = client.sub(&x_row, &x_col)?;
    let dy = client.sub(&y_row, &y_col)?;

    // A pair is usable only when the x values differ (exact test; an absolute
    // epsilon would discard distinct pairs on small-scale data) and i != j.
    let zero = Tensor::<R>::full_scalar(&[1, 1], dtype, 0.0, device);
    let dx_abs = client.abs(&dx)?;
    let nonzero_mask = client.gt(&dx_abs, &zero)?;
    let row_idx = client.arange(0.0, n as f64, 1.0, dtype)?.reshape(&[n, 1])?;
    let col_idx = client.arange(0.0, n as f64, 1.0, dtype)?.reshape(&[1, n])?;
    let not_diag = client.ne(&row_idx, &col_idx)?;
    let valid_mask = client.mul(&nonzero_mask, &not_diag)?;
    let slopes_matrix = client.div(&dy, &dx)?;

    // Replace unusable entries with +inf so that, after an ascending row sort,
    // the first `count` entries of each row are exactly that row's valid slopes.
    let inf_val = Tensor::<R>::full_scalar(&[n, n], dtype, f64::INFINITY, device);
    let slopes_clean = client.where_cond(&valid_mask, &slopes_matrix, &inf_val)?;
    let sorted_rows = client.sort(&slopes_clean, 1, false)?;
    let ones = Tensor::<R>::full_scalar(&[n, n], dtype, 1.0, device);
    let zeros = Tensor::<R>::full_scalar(&[n, n], dtype, 0.0, device);
    let valid_float = client.where_cond(&valid_mask, &ones, &zeros)?;
    let counts = client.sum(&valid_float, &[1], false)?;

    // Per-row median = mean of the lower and upper middle elements among the
    // `count` valid entries: positions floor((count - 1) / 2) and floor(count / 2).
    // They coincide when `count` is odd. For rows with count == 0 the lower
    // position would be -1, so it is clamped to 0; such rows are filtered out below.
    let one = Tensor::<R>::full_scalar(counts.shape(), dtype, 1.0, device);
    let two = Tensor::<R>::full_scalar(counts.shape(), dtype, 2.0, device);
    let upper_pos = client.floor(&client.div(&counts, &two)?)?;
    let lower_pos = client.clamp(
        &client.floor(&client.div(&client.sub(&counts, &one)?, &two)?)?,
        0.0,
        n as f64,
    )?;
    let lower_idx = client
        .cast(&lower_pos, numr::dtype::DType::I64)?
        .reshape(&[n, 1])?;
    let upper_idx = client
        .cast(&upper_pos, numr::dtype::DType::I64)?
        .reshape(&[n, 1])?;
    let lower_vals = client.gather(&sorted_rows, 1, &lower_idx)?;
    let upper_vals = client.gather(&sorted_rows, 1, &upper_idx)?;
    let half = Tensor::<R>::full_scalar(&[n, 1], dtype, 0.5, device);
    let per_point_medians = client.mul(&client.add(&lower_vals, &upper_vals)?, &half)?;
    let per_point_medians_flat = per_point_medians.reshape(&[n])?;

    let zero_t = Tensor::<R>::full_scalar(counts.shape(), dtype, 0.0, device);
    let has_valid = client.gt(&counts, &zero_t)?;
    let has_valid_u8 = client.cast(&has_valid, numr::dtype::DType::U8)?;
    let valid_medians = client.masked_select(&per_point_medians_flat, &has_valid_u8)?;
    if valid_medians.numel() == 0 {
        return Err(Error::InvalidArgument {
            arg: "x",
            reason: "all x values are identical".to_string(),
        });
    }

    let slope = tensor_median_scalar(client, &valid_medians)?;
    let slope_t = Tensor::<R>::full_scalar(x_contig.shape(), dtype, slope, device);
    let slope_x = client.mul(&slope_t, &x_contig)?;
    let intercepts = client.sub(&y_contig, &slope_x)?;
    let intercept = tensor_median_scalar(client, &intercepts)?;

    // Bounds are the extreme per-point medians. The previous Sen-style index
    // c = round(1.96 * sqrt(n(n-1)(2n+5)/18)) is >= n >= ms for every n >= 2,
    // so it always reduced to exactly these two order statistics.
    let ms = valid_medians.numel();
    let sorted_medians = client.sort(&valid_medians, 0, false)?;
    let low_slope = extract_scalar(&sorted_medians.narrow(0, 0, 1)?)?;
    let high_slope = extract_scalar(&sorted_medians.narrow(0, ms - 1, 1)?)?;

    Ok(RobustRegressionResult {
        slope: Tensor::<R>::full_scalar(&[], dtype, slope, device),
        intercept: Tensor::<R>::full_scalar(&[], dtype, intercept, device),
        low_slope: Tensor::<R>::full_scalar(&[], dtype, low_slope, device),
        high_slope: Tensor::<R>::full_scalar(&[], dtype, high_slope, device),
    })
}