use scirs2_core::ndarray::{Array1, ArrayBase, Data, Dimension};
use scirs2_core::numeric::{Float, FromPrimitive, NumCast};
use super::{check_non_negative, check_positive, check_sameshape};
use crate::error::{MetricsError, Result};
/// Mean Poisson deviance between `y_true` and `y_pred`.
///
/// Returns `2 / n * Σ d(y, μ)` with the unit deviance `d(y, μ) = y·ln(y/μ) − (y − μ)`.
/// For a zero target the limit `d(0, μ) = μ` is used.
///
/// # Errors
///
/// Propagates errors from `check_sameshape` and `check_non_negative`. Returns
/// [`MetricsError::InvalidInput`] if any prediction is not strictly positive (zero or NaN).
#[allow(dead_code)]
pub fn mean_poisson_deviance<F, S1, S2, D1, D2>(
    y_true: &ArrayBase<S1, D1>,
    y_pred: &ArrayBase<S2, D2>,
) -> Result<F>
where
    F: Float + NumCast + std::fmt::Debug,
    S1: Data<Elem = F>,
    S2: Data<Elem = F>,
    D1: Dimension,
    D2: Dimension,
{
    check_sameshape::<F, S1, S2, D1, D2>(y_true, y_pred)?;
    check_non_negative::<F, S1, S2, D1, D2>(y_true, y_pred)?;
    let n_samples = y_true.len();
    let mut deviance_sum = F::zero();
    for (&yt, &yp) in y_true.iter().zip(y_pred.iter()) {
        // Compare against zero rather than `F::epsilon()`. Epsilon is a relative precision
        // bound, not a magnitude threshold, so tiny positive predictions are valid inputs.
        // The explicit NaN test stops NaN predictions from passing the comparison.
        if yp.is_nan() || yp <= F::zero() {
            return Err(MetricsError::InvalidInput(
                "Predicted values must be positive for Poisson deviance".to_string(),
            ));
        }
        let ratio = yt / yp;
        let unit_deviance = if ratio > F::zero() {
            yt * ratio.ln() - (yt - yp)
        } else {
            // Zero target (or a ratio that underflowed to 0): y·ln(y/μ) -> 0, leaving μ.
            yp
        };
        deviance_sum = deviance_sum + unit_deviance;
    }
    Ok(
        F::from(2.0).expect("Failed to convert constant to float") * deviance_sum
            / NumCast::from(n_samples).expect("Operation failed"),
    )
}
/// Mean Gamma deviance between `y_true` and `y_pred`.
///
/// Returns `2 / n * Σ (−ln(y/μ) + (y − μ)/μ)` over all element pairs.
///
/// # Errors
///
/// Propagates errors from `check_sameshape` and `check_positive`; input validation is
/// delegated entirely to those helpers.
#[allow(dead_code)]
pub fn mean_gamma_deviance<F, S1, S2, D1, D2>(
    y_true: &ArrayBase<S1, D1>,
    y_pred: &ArrayBase<S2, D2>,
) -> Result<F>
where
    F: Float + NumCast + std::fmt::Debug,
    S1: Data<Elem = F>,
    S2: Data<Elem = F>,
    D1: Dimension,
    D2: Dimension,
{
    check_sameshape::<F, S1, S2, D1, D2>(y_true, y_pred)?;
    check_positive::<F, S1, S2, D1, D2>(y_true, y_pred)?;

    let total = y_true
        .iter()
        .zip(y_pred.iter())
        .fold(F::zero(), |acc, (&y, &mu)| acc - (y / mu).ln() + (y - mu) / mu);

    let factor = F::from(2.0).expect("Failed to convert constant to float");
    let count: F = NumCast::from(y_true.len()).expect("Operation failed");
    Ok(factor * total / count)
}
/// Mean Tweedie deviance with power parameter `power` (`p`).
///
/// Special cases, each matched within `F::epsilon()`:
/// - `p == 0`: mean squared error,
/// - `p == 1`: [`mean_poisson_deviance`],
/// - `p == 2`: [`mean_gamma_deviance`].
///
/// For every other power the unit deviance is
/// `d(y, μ) = 2·( max(y, 0)^(2−p) / ((1−p)(2−p)) − y·μ^(1−p) / (1−p) + μ^(2−p) / (2−p) )`.
/// It is zero when `y == μ`. The mean over all elements is returned.
///
/// Input domain for the general case:
/// - `p < 1`: every prediction must be strictly positive; targets may be any real value,
/// - `1 < p < 2`: targets must be non-negative and predictions strictly positive,
/// - `p > 2`: both arrays are validated by `check_positive`.
///
/// # Errors
///
/// Propagates errors from `check_sameshape` and `check_positive`, and from the delegated
/// Poisson/Gamma functions. Returns [`MetricsError::InvalidInput`] when an input violates
/// the domain listed above.
#[allow(dead_code)]
pub fn tweedie_deviance_score<F, S1, S2, D1, D2>(
    y_true: &ArrayBase<S1, D1>,
    y_pred: &ArrayBase<S2, D2>,
    power: F,
) -> Result<F>
where
    F: Float + NumCast + std::fmt::Debug + FromPrimitive,
    S1: Data<Elem = F>,
    S2: Data<Elem = F>,
    D1: Dimension,
    D2: Dimension,
{
    check_sameshape::<F, S1, S2, D1, D2>(y_true, y_pred)?;
    let one = F::one();
    let two = F::from(2.0).expect("Failed to convert constant to float");

    if power.abs() < F::epsilon() {
        let sum_squared_error = y_true
            .iter()
            .zip(y_pred.iter())
            .fold(F::zero(), |acc, (&yt, &yp)| {
                let error = yt - yp;
                acc + error * error
            });
        return Ok(sum_squared_error / NumCast::from(y_true.len()).expect("Operation failed"));
    }
    if (power - one).abs() < F::epsilon() {
        return mean_poisson_deviance(y_true, y_pred);
    }
    if (power - two).abs() < F::epsilon() {
        return mean_gamma_deviance(y_true, y_pred);
    }

    if power < one {
        // μ^(1−p) and μ^(2−p) are only real-valued (and the formula only meaningful) for μ > 0.
        if y_pred.iter().any(|&val| val.is_nan() || val <= F::zero()) {
            return Err(MetricsError::InvalidInput(
                "y_pred contains non-positive values, which is not allowed for this power parameter".to_string(),
            ));
        }
    } else if power < two {
        if y_true.iter().any(|&val| val < F::zero()) {
            return Err(MetricsError::InvalidInput(
                "y_true contains negative values, which is not allowed for this power parameter".to_string(),
            ));
        }
        if y_pred.iter().any(|&val| val.is_nan() || val <= F::zero()) {
            return Err(MetricsError::InvalidInput(
                "y_pred contains non-positive values, which is not allowed for this power parameter".to_string(),
            ));
        }
    } else {
        check_positive::<F, S1, S2, D1, D2>(y_true, y_pred)?;
    }

    let n_samples = y_true.len();
    let one_minus_p = one - power;
    let two_minus_p = two - power;
    let mut deviance_sum = F::zero();
    for (&yt, &yp) in y_true.iter().zip(y_pred.iter()) {
        // Clamping the target at zero keeps the first term real for negative targets (allowed
        // when p < 1). It also makes a zero target contribute 0^(2−p) = 0 when 1 < p < 2,
        // avoiding the 0·∞ that y·y^(1−p) would produce there.
        let unit_deviance = yt.max(F::zero()).powf(two_minus_p) / (one_minus_p * two_minus_p)
            - yt * yp.powf(one_minus_p) / one_minus_p
            + yp.powf(two_minus_p) / two_minus_p;
        deviance_sum = deviance_sum + two * unit_deviance;
    }
    Ok(deviance_sum / NumCast::from(n_samples).expect("Operation failed"))
}
/// Mean quantile (pinball) loss at level `quantile` (τ).
///
/// For each error `e = y_true − y_pred` the loss is `τ·e` when `e ≥ 0` and `(τ − 1)·e`
/// otherwise. Under-prediction is therefore weighted by `τ` and over-prediction by `1 − τ`.
/// The mean over all elements is returned.
///
/// # Errors
///
/// Propagates errors from `check_sameshape`. Returns [`MetricsError::InvalidInput`] if
/// `quantile` is not strictly inside `(0, 1)`, including NaN.
#[allow(dead_code)]
pub fn quantile_loss<F, S1, S2, D1, D2>(
    y_true: &ArrayBase<S1, D1>,
    y_pred: &ArrayBase<S2, D2>,
    quantile: F,
) -> Result<F>
where
    F: Float + NumCast + std::fmt::Debug + std::fmt::Display + FromPrimitive,
    S1: Data<Elem = F>,
    S2: Data<Elem = F>,
    D1: Dimension,
    D2: Dimension,
{
    check_sameshape::<F, S1, S2, D1, D2>(y_true, y_pred)?;
    // NaN fails both range comparisons, so it must be rejected explicitly.
    if quantile.is_nan() || quantile <= F::zero() || quantile >= F::one() {
        return Err(MetricsError::InvalidInput(format!(
            "Quantile must be between 0 and 1, got {}",
            quantile
        )));
    }
    let n_samples = y_true.len();
    let mut loss_sum = F::zero();
    for (&yt, &yp) in y_true.iter().zip(y_pred.iter()) {
        let error = yt - yp;
        let slope = if error >= F::zero() {
            quantile
        } else {
            quantile - F::one()
        };
        loss_sum = loss_sum + slope * error;
    }
    Ok(loss_sum / NumCast::from(n_samples).expect("Operation failed"))
}
/// Computes per-element robust weights for `residuals`.
///
/// The residual scale is `s = median(|r|) / 0.6745`. This is the median absolute residual
/// about zero, not about the median of `r`. When `s` is not above `F::epsilon()`, `1e-8`
/// is used instead. With tuning constant `c` and `u = |r| / (c · s)`:
/// - `"huber"`:    `w = 1` if `u <= 1`, else `1 / u` (default `c = 1.345`),
/// - `"bisquare"`: `w = (1 − u²)²` if `u < 1`, else `0` (default `c = 4.685`),
/// - `"cauchy"`:   `w = 1 / (1 + u²)` (default `c = 2.385`).
///
/// The returned array holds one weight per residual, in the iteration order of `residuals`.
///
/// # Errors
///
/// Returns [`MetricsError::InvalidInput`] in three cases:
/// - `residuals` is empty,
/// - `method` is not one of the three names above,
/// - the tuning constant is NaN or not strictly positive.
#[allow(dead_code)]
pub fn compute_robust_weights<F, S, D>(
    residuals: &ArrayBase<S, D>,
    method: &str,
    tuning: Option<F>,
) -> Result<Array1<F>>
where
    F: Float + NumCast + std::fmt::Debug + FromPrimitive,
    S: Data<Elem = F>,
    D: Dimension,
{
    #[derive(Clone, Copy)]
    enum WeightFn {
        Huber,
        Bisquare,
        Cauchy,
    }

    let n = residuals.len();
    if n == 0 {
        return Err(MetricsError::InvalidInput(
            "Empty residuals array".to_string(),
        ));
    }

    // Resolve the method once, up front. This also supplies its default tuning constant.
    let (kind, default_tuning) = match method {
        "huber" => (WeightFn::Huber, 1.345),
        "bisquare" => (WeightFn::Bisquare, 4.685),
        "cauchy" => (WeightFn::Cauchy, 2.385),
        _ => {
            return Err(MetricsError::InvalidInput(format!(
                "Unknown weight method: {}. Valid options are 'huber', 'bisquare', 'cauchy'.",
                method
            )));
        }
    };
    let c = match tuning {
        Some(t) => t,
        None => F::from(default_tuning).expect("Failed to convert constant to float"),
    };
    // A zero, negative or NaN constant makes `u` meaningless (division by zero / sign flip).
    if c.is_nan() || c <= F::zero() {
        return Err(MetricsError::InvalidInput(format!(
            "Tuning constant must be positive, got {:?}",
            c
        )));
    }

    let abs_residuals: Vec<F> = residuals.iter().map(|&r| r.abs()).collect();
    let mut sorted_abs = abs_residuals.clone();
    sorted_abs.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
    let median_idx = n / 2;
    let median_abs = if n.is_multiple_of(2) {
        (sorted_abs[median_idx - 1] + sorted_abs[median_idx])
            / F::from(2.0).expect("Failed to convert constant to float")
    } else {
        sorted_abs[median_idx]
    };
    let scale = median_abs / F::from(0.6745).expect("Failed to convert constant to float");
    let s = if scale > F::epsilon() {
        scale
    } else {
        F::from(1e-8).expect("Failed to convert constant to float")
    };

    // Loop-invariant denominator of u.
    let denom = c * s;
    let weights: Vec<F> = abs_residuals
        .iter()
        .map(|&r| {
            let u = r / denom;
            match kind {
                WeightFn::Huber => {
                    if u <= F::one() {
                        F::one()
                    } else {
                        F::one() / u
                    }
                }
                WeightFn::Bisquare => {
                    if u < F::one() {
                        let temp = F::one() - u * u;
                        temp * temp
                    } else {
                        F::zero()
                    }
                }
                WeightFn::Cauchy => F::one() / (F::one() + u * u),
            }
        })
        .collect();
    Ok(Array1::from_vec(weights))
}
/// Weighted mean squared error: `Σ wᵢ·(yᵢ − ŷᵢ)² / Σ wᵢ`.
///
/// `weights` must contain exactly as many elements as `y_true`. Elements are paired in
/// iteration order.
///
/// # Errors
///
/// Propagates errors from `check_sameshape`. Returns [`MetricsError::InvalidInput`] in three
/// cases:
/// - the weight count differs from the sample count,
/// - a weight is negative (reported at the first one encountered),
/// - the weights sum to at most `F::epsilon()`.
#[allow(dead_code)]
pub fn weighted_mean_squared_error<F, S1, S2, S3, D1, D2, D3>(
    y_true: &ArrayBase<S1, D1>,
    y_pred: &ArrayBase<S2, D2>,
    weights: &ArrayBase<S3, D3>,
) -> Result<F>
where
    F: Float + NumCast + std::fmt::Debug,
    S1: Data<Elem = F>,
    S2: Data<Elem = F>,
    S3: Data<Elem = F>,
    D1: Dimension,
    D2: Dimension,
    D3: Dimension,
{
    check_sameshape::<F, S1, S2, D1, D2>(y_true, y_pred)?;
    let expected_len = y_true.len();
    let weight_len = weights.len();
    if weight_len != expected_len {
        return Err(MetricsError::InvalidInput(format!(
            "Weights length ({}) must match y_true length ({})",
            weight_len, expected_len
        )));
    }

    let (numerator, denominator) = y_true
        .iter()
        .zip(y_pred.iter())
        .zip(weights.iter())
        .try_fold(
            (F::zero(), F::zero()),
            |(num, den), ((&actual, &predicted), &weight)| -> Result<(F, F)> {
                if weight < F::zero() {
                    return Err(MetricsError::InvalidInput(
                        "Weights must be non-negative".to_string(),
                    ));
                }
                let residual = actual - predicted;
                Ok((num + weight * residual * residual, den + weight))
            },
        )?;

    if denominator <= F::epsilon() {
        return Err(MetricsError::InvalidInput(
            "Sum of weights is zero".to_string(),
        ));
    }
    Ok(numerator / denominator)
}
/// Weighted median of the absolute errors `|y_true − y_pred|`.
///
/// Errors are sorted ascending (stable sort). The first error whose cumulative weight
/// reaches half of the total weight is returned; this is the lower weighted median. If
/// no entry reaches that point, the largest error is returned.
///
/// # Errors
///
/// Propagates errors from `check_sameshape`. Returns [`MetricsError::InvalidInput`] in three
/// cases:
/// - the weight count differs from the sample count,
/// - a weight is negative,
/// - the weights sum to at most `F::epsilon()`.
#[allow(dead_code)]
pub fn weighted_median_absolute_error<F, S1, S2, S3, D1, D2, D3>(
    y_true: &ArrayBase<S1, D1>,
    y_pred: &ArrayBase<S2, D2>,
    weights: &ArrayBase<S3, D3>,
) -> Result<F>
where
    F: Float + NumCast + std::fmt::Debug,
    S1: Data<Elem = F>,
    S2: Data<Elem = F>,
    S3: Data<Elem = F>,
    D1: Dimension,
    D2: Dimension,
    D3: Dimension,
{
    check_sameshape::<F, S1, S2, D1, D2>(y_true, y_pred)?;
    let count = y_true.len();
    if weights.len() != count {
        return Err(MetricsError::InvalidInput(format!(
            "Weights length ({}) must match y_true length ({})",
            weights.len(),
            count
        )));
    }

    // (absolute error, weight) pairs, built in iteration order.
    let mut pairs: Vec<(F, F)> = Vec::with_capacity(count);
    let mut total_weight = F::zero();
    for ((&actual, &predicted), &weight) in y_true.iter().zip(y_pred.iter()).zip(weights.iter()) {
        if weight < F::zero() {
            return Err(MetricsError::InvalidInput(
                "Weights must be non-negative".to_string(),
            ));
        }
        pairs.push(((actual - predicted).abs(), weight));
        total_weight = total_weight + weight;
    }
    if total_weight <= F::epsilon() {
        return Err(MetricsError::InvalidInput(
            "Sum of weights is zero".to_string(),
        ));
    }

    pairs.sort_by(|lhs, rhs| {
        lhs.0
            .partial_cmp(&rhs.0)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    let half_weight = total_weight / F::from(2.0).expect("Failed to convert constant to float");

    let mut running_weight = F::zero();
    let median = pairs
        .iter()
        .find(|&&(_, weight)| {
            running_weight = running_weight + weight;
            running_weight >= half_weight
        })
        .map(|&(error, _)| error);

    match median {
        Some(error) => Ok(error),
        None => Ok(pairs.last().expect("Operation failed").0),
    }
}
#[allow(dead_code)]
pub fn m_estimator<F, S1, S2, D1, D2>(
y_true: &ArrayBase<S1, D1>,
y_pred: &ArrayBase<S2, D2>,
loss_function: &str,
tuning: Option<F>,
) -> Result<F>
where
F: Float + NumCast + std::fmt::Debug + FromPrimitive,
S1: Data<Elem = F>,
S2: Data<Elem = F>,
D1: Dimension,
D2: Dimension,
{
check_sameshape::<F, S1, S2, D1, D2>(y_true, y_pred)?;
let n_samples = y_true.len();
let mut residuals = Vec::with_capacity(n_samples);
for (yt, yp) in y_true.iter().zip(y_pred.iter()) {
residuals.push(*yt - *yp);
}
let weights =
compute_robust_weights(&Array1::from_vec(residuals.clone()), loss_function, tuning)?;
let mut loss_sum = F::zero();
let mut weight_sum = F::zero();
for (i, &residual) in residuals.iter().enumerate() {
let w = weights[i];
loss_sum = loss_sum + w * residual * residual;
weight_sum = weight_sum + w;
}
if weight_sum <= F::epsilon() {
return Err(MetricsError::InvalidInput(
"Sum of weights is zero".to_string(),
));
}
Ok(loss_sum / weight_sum)
}