use num_traits::Float;
use crate::math::scaling::ScalingMethod;
/// Weight function used to down-weight observations with large residuals during
/// robust (iteratively reweighted) fitting.
///
/// Each variant maps a residual, normalised by a robust scale estimate, to a weight
/// in `[0, 1]`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum RobustnessMethod {
    /// Tukey bisquare (biweight): smoothly decreasing weight `(1 - u²)²`, reaching
    /// zero at the cutoff. This is the default.
    #[default]
    Bisquare,
    /// Huber: full weight inside the cutoff, weight inversely proportional to the
    /// residual magnitude beyond it (never exactly zero).
    Huber,
    /// Talwar: full weight inside the cutoff, zero weight beyond it (hard rejection).
    Talwar,
}
impl RobustnessMethod {
    /// Default bisquare tuning constant: weights reach zero at `|r| = c * scale`.
    const DEFAULT_BISQUARE_C: f64 = 6.0;
    /// Default Huber tuning constant: down-weighting starts beyond `|r| = c * scale`.
    const DEFAULT_HUBER_C: f64 = 1.345;
    /// Default Talwar tuning constant: residuals beyond `|r| = c * scale` get weight zero.
    const DEFAULT_TALWAR_C: f64 = 2.5;
    /// Fraction of the mean absolute residual at or below which the scale estimate
    /// produced by the [`ScalingMethod`] is treated as degenerate.
    const SCALE_THRESHOLD: f64 = 1e-7;
    /// Absolute floor used for degenerate-scale detection and for the bisquare tuning
    /// constant / tuned scale, keeping the bisquare denominator away from zero.
    const MIN_TUNED_SCALE: f64 = 1e-12;

    /// Computes a robustness weight for every residual and writes it into `weights`.
    ///
    /// A scale is first estimated from `residuals` using `scaling_method` (falling back
    /// to the mean absolute residual when that estimate is degenerate). Each residual is
    /// then mapped to a weight by this method's weight function with its default tuning
    /// constant.
    ///
    /// * `residuals` – residuals of the current fit.
    /// * `weights` – output buffer; `weights[i]` receives the weight of `residuals[i]`.
    ///   Entries beyond `residuals.len()` are left untouched.
    /// * `scaling_method` – estimator used for the residual scale.
    /// * `scratch` – working storage for the scale estimate. Its first
    ///   `residuals.len()` entries are overwritten, and it may be longer than `residuals`.
    ///
    /// Does nothing when `residuals` is empty.
    ///
    /// # Panics
    ///
    /// Panics if `weights` is shorter than `residuals`, or if the scale estimate needs
    /// `scratch` and `scratch` is shorter than `residuals`.
    pub fn apply_robustness_weights<T: Float>(
        &self,
        residuals: &[T],
        weights: &mut [T],
        scaling_method: ScalingMethod,
        scratch: &mut [T],
    ) {
        if residuals.is_empty() {
            return;
        }
        // Check up front so a short buffer fails loudly before any weight is written.
        assert!(
            weights.len() >= residuals.len(),
            "weights buffer (len {}) is shorter than residuals (len {})",
            weights.len(),
            residuals.len(),
        );

        let scale = self.compute_scale(residuals, scaling_method, scratch);
        let tuning_constant = match self {
            Self::Bisquare => Self::DEFAULT_BISQUARE_C,
            Self::Huber => Self::DEFAULT_HUBER_C,
            Self::Talwar => Self::DEFAULT_TALWAR_C,
        };
        let c = T::from(tuning_constant).expect("tuning constant must be representable in T");

        // `zip` stops at the shorter slice; the assertion above guarantees that is
        // `residuals`, so every residual receives a weight.
        for (weight, &residual) in weights.iter_mut().zip(residuals) {
            *weight = match self {
                Self::Bisquare => Self::bisquare_weight(residual, scale, c),
                Self::Huber => Self::huber_weight(residual, scale, c),
                Self::Talwar => Self::talwar_weight(residual, scale, c),
            };
        }
    }

    /// Estimates the scale used to normalise residuals before weighting.
    ///
    /// Returns zero when `residuals` is empty or every residual is zero; the weight
    /// functions treat a non-positive scale as "give everything weight one".
    /// Otherwise the estimate from `scaling_method` is returned, unless it is at or below
    /// `max(SCALE_THRESHOLD * mae, MIN_TUNED_SCALE)`. In that case the larger of the mean
    /// absolute residual (`mae`) and the estimate is returned instead.
    ///
    /// The scaling method operates on a copy of `residuals` placed in the first
    /// `residuals.len()` entries of `scratch`. Any in-place work it does (e.g. partial
    /// sorting — presumably for a median) therefore never touches `residuals`.
    ///
    /// # Panics
    ///
    /// Panics if the estimate is needed and `scratch` is shorter than `residuals`.
    fn compute_scale<T: Float>(
        &self,
        residuals: &[T],
        scaling_method: ScalingMethod,
        scratch: &mut [T],
    ) -> T {
        let n = residuals.len();
        if n == 0 {
            return T::zero();
        }

        let sum_abs = residuals.iter().fold(T::zero(), |acc, &r| acc + r.abs());
        let mae = sum_abs / T::from(n).expect("sample count must be representable in T");
        if mae.is_zero() {
            return T::zero();
        }

        let relative_threshold =
            T::from(Self::SCALE_THRESHOLD).expect("threshold must be representable in T") * mae;
        let absolute_threshold =
            T::from(Self::MIN_TUNED_SCALE).expect("threshold must be representable in T");
        let scale_threshold = relative_threshold.max(absolute_threshold);

        assert!(
            scratch.len() >= n,
            "scratch buffer (len {}) is shorter than residuals (len {})",
            scratch.len(),
            n,
        );
        // Only the `n`-length prefix is used, so callers may reuse a larger buffer.
        let work = &mut scratch[..n];
        work.copy_from_slice(residuals);
        let scale_val = scaling_method.compute(work);

        if scale_val <= scale_threshold {
            // Degenerate estimate (e.g. most residuals identical): fall back to the MAE.
            mae.max(scale_val)
        } else {
            scale_val
        }
    }

    /// Tukey bisquare weight for `residual`, given `scale` and tuning constant `c`.
    ///
    /// With `u = |residual| / (scale * c)`, returns `(1 - u²)²`. The result is snapped
    /// to exactly `1` for `u <= 0.001` and to exactly `0` for `u >= 0.999`. Both `c` and
    /// `scale * c` are floored at `MIN_TUNED_SCALE`. A non-positive `scale` yields
    /// weight `1`.
    #[inline]
    pub(crate) fn bisquare_weight<T: Float>(residual: T, scale: T, c: T) -> T {
        if scale <= T::zero() {
            return T::one();
        }
        let min_eps =
            T::from(Self::MIN_TUNED_SCALE).expect("epsilon must be representable in T");
        let c_clamped = c.max(min_eps);
        let tuned_scale = (scale * c_clamped).max(min_eps);
        let u = (residual / tuned_scale).abs();

        let low_threshold = T::from(0.001).expect("threshold must be representable in T");
        let high_threshold = T::from(0.999).expect("threshold must be representable in T");
        if u >= high_threshold {
            T::zero()
        } else if u <= low_threshold {
            T::one()
        } else {
            let tmp = T::one() - u * u;
            tmp * tmp
        }
    }

    /// Huber weight for `residual`, given `scale` and tuning constant `c`.
    ///
    /// With `u = |residual / scale|`, returns `1` when `u <= c` and `c / u` otherwise.
    /// A non-positive `scale` yields weight `1`.
    #[inline]
    pub(crate) fn huber_weight<T: Float>(residual: T, scale: T, c: T) -> T {
        if scale <= T::zero() {
            return T::one();
        }
        let u = (residual / scale).abs();
        if u <= c { T::one() } else { c / u }
    }

    /// Talwar (hard-rejection) weight for `residual`, given `scale` and tuning constant `c`.
    ///
    /// With `u = |residual / scale|`, returns `1` when `u <= c` and `0` otherwise.
    /// A non-positive `scale` yields weight `1`.
    #[inline]
    pub(crate) fn talwar_weight<T: Float>(residual: T, scale: T, c: T) -> T {
        if scale <= T::zero() {
            return T::one();
        }
        let u = (residual / scale).abs();
        if u <= c { T::one() } else { T::zero() }
    }
}