//! Robust least-squares fitting via M-estimators.
//!
//! Provides a family of robust loss functions (`SquaredLoss`, `HuberLoss`,
//! `BisquareLoss`, `CauchyLoss`) and an iteratively reweighted least
//! squares (IRLS) driver, `robust_least_squares`.
use crate::error::OptimizeResult;
use crate::result::OptimizeResults;
use scirs2_core::ndarray::{Array1, Array2, ArrayBase, Data, Ix1};
/// A robust loss function for M-estimation. For IRLS, `weight(r)` must
/// equal `rho'(r) / r`, where `rho` is the loss.
pub trait RobustLoss: Clone {
    /// Loss value `rho(r)` for a single residual.
    fn loss(&self, r: f64) -> f64;
    /// IRLS weight `w(r) = rho'(r) / r`.
    fn weight(&self, r: f64) -> f64;
    /// Derivative of the weight with respect to the residual `r`.
    fn weight_derivative(&self, r: f64) -> f64;
}
/// Ordinary (non-robust) squared loss: `rho(r) = r^2 / 2`, so every
/// residual receives unit weight.
#[derive(Debug, Clone)]
pub struct SquaredLoss;
impl RobustLoss for SquaredLoss {
    fn loss(&self, r: f64) -> f64 {
        0.5 * r * r
    }
    fn weight(&self, _r: f64) -> f64 {
        1.0
    }
    fn weight_derivative(&self, _r: f64) -> f64 {
        0.0
    }
}
/// Huber loss: quadratic for `|r| <= delta`, linear beyond. A common
/// compromise between least-squares efficiency and outlier resistance.
#[derive(Debug, Clone)]
pub struct HuberLoss {
    delta: f64,
}
impl HuberLoss {
    pub fn new(delta: f64) -> Self {
        assert!(delta > 0.0, "Delta must be positive");
        HuberLoss { delta }
    }
}
impl RobustLoss for HuberLoss {
    fn loss(&self, r: f64) -> f64 {
        let abs_r = r.abs();
        if abs_r <= self.delta {
            0.5 * r * r
        } else {
            self.delta * (abs_r - 0.5 * self.delta)
        }
    }
    fn weight(&self, r: f64) -> f64 {
        let abs_r = r.abs();
        if abs_r <= self.delta {
            1.0
        } else {
            self.delta / abs_r
        }
    }
    fn weight_derivative(&self, r: f64) -> f64 {
        let abs_r = r.abs();
        if abs_r <= self.delta {
            0.0
        } else {
            // d(delta / |r|)/dr = -delta * sign(r) / r^2.
            -self.delta * r.signum() / (abs_r * abs_r)
        }
    }
}
/// Tukey's bisquare (biweight) loss. Residuals with `|r| > c` are rejected
/// outright (zero weight); `c = 4.685` gives 95% asymptotic efficiency for
/// Gaussian errors.
#[derive(Debug, Clone)]
pub struct BisquareLoss {
    c: f64,
}
impl BisquareLoss {
    pub fn new(c: f64) -> Self {
        assert!(c > 0.0, "Tuning constant must be positive");
        BisquareLoss { c }
    }
}
impl RobustLoss for BisquareLoss {
    fn loss(&self, r: f64) -> f64 {
        let abs_r = r.abs();
        if abs_r <= self.c {
            let u = r / self.c;
            (self.c * self.c / 6.0) * (1.0 - (1.0 - u * u).powi(3))
        } else {
            // The loss saturates at c^2 / 6 beyond the cutoff.
            self.c * self.c / 6.0
        }
    }
    fn weight(&self, r: f64) -> f64 {
        let abs_r = r.abs();
        if abs_r < 1e-10 {
            1.0
        } else if abs_r <= self.c {
            let u = r / self.c;
            (1.0 - u * u).powi(2)
        } else {
            0.0
        }
    }
    fn weight_derivative(&self, r: f64) -> f64 {
        let abs_r = r.abs();
        if abs_r <= self.c && abs_r >= 1e-10 {
            let u = r / self.c;
            // d((1 - u^2)^2)/dr = -4 u (1 - u^2) / c, with u = r / c.
            -4.0 * u * (1.0 - u * u) / self.c
        } else {
            0.0
        }
    }
}
/// Cauchy (Lorentzian) loss: `rho(r) = (c^2 / 2) ln(1 + (r/c)^2)`. Grows
/// only logarithmically, so large outliers have little influence.
#[derive(Debug, Clone)]
pub struct CauchyLoss {
    c: f64,
}
impl CauchyLoss {
    pub fn new(c: f64) -> Self {
        assert!(c > 0.0, "Scale parameter must be positive");
        CauchyLoss { c }
    }
}
impl RobustLoss for CauchyLoss {
    fn loss(&self, r: f64) -> f64 {
        let u = r / self.c;
        (self.c * self.c / 2.0) * (1.0 + u * u).ln()
    }
    fn weight(&self, r: f64) -> f64 {
        if r.abs() < 1e-10 {
            1.0
        } else {
            let u = r / self.c;
            1.0 / (1.0 + u * u)
        }
    }
    fn weight_derivative(&self, r: f64) -> f64 {
        if r.abs() < 1e-10 {
            0.0
        } else {
            let u = r / self.c;
            let denom = 1.0 + u * u;
            // d(1 / (1 + u^2))/dr = -2 u / (c (1 + u^2)^2), with u = r / c.
            -2.0 * u / (self.c * denom * denom)
        }
    }
}
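// Sketch of how a further M-estimator plugs into `RobustLoss`. The Fair loss
// below is a standard convex alternative to Huber, but this particular
// implementation is an illustrative, hypothetical addition rather than part
// of this module's established set of losses.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct FairLoss {
    c: f64,
}
impl FairLoss {
    #[allow(dead_code)]
    pub fn new(c: f64) -> Self {
        assert!(c > 0.0, "Scale parameter must be positive");
        FairLoss { c }
    }
}
impl RobustLoss for FairLoss {
    // rho(r) = c^2 (|r|/c - ln(1 + |r|/c)).
    fn loss(&self, r: f64) -> f64 {
        let u = r.abs() / self.c;
        self.c * self.c * (u - (1.0 + u).ln())
    }
    // w(r) = rho'(r) / r = 1 / (1 + |r|/c).
    fn weight(&self, r: f64) -> f64 {
        1.0 / (1.0 + r.abs() / self.c)
    }
    // dw/dr = -sign(r) / (c (1 + |r|/c)^2); 0 at the origin by convention.
    fn weight_derivative(&self, r: f64) -> f64 {
        if r.abs() < 1e-10 {
            0.0
        } else {
            let denom = 1.0 + r.abs() / self.c;
            -r.signum() / (self.c * denom * denom)
        }
    }
}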
/// Options controlling the robust least-squares solvers.
#[derive(Debug, Clone)]
pub struct RobustOptions {
    /// Outer iteration budget; also used to size the default `max_nfev`.
    pub max_iter: usize,
    /// Maximum residual evaluations (default: `max_iter * n_params * 10`).
    pub max_nfev: Option<usize>,
    /// Step-size tolerance relative to the parameter norm.
    pub xtol: f64,
    /// Cost-change tolerance (not yet used by the IRLS path).
    pub ftol: f64,
    /// Gradient tolerance (not yet used by the IRLS path).
    pub gtol: f64,
    /// If true (the default), use IRLS; otherwise the gradient-based path.
    pub use_irls: bool,
    /// Mean absolute change in IRLS weights below which iteration stops.
    pub weight_tol: f64,
    /// Maximum number of IRLS iterations.
    pub irls_max_iter: usize,
}
impl Default for RobustOptions {
    fn default() -> Self {
        RobustOptions {
            max_iter: 100,
            max_nfev: None,
            xtol: 1e-8,
            ftol: 1e-8,
            gtol: 1e-8,
            use_irls: true,
            weight_tol: 1e-4,
            irls_max_iter: 20,
        }
    }
}
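// Sketch of overriding selected options; struct-update syntax keeps every
// other field at its `Default` value. `strict_options` is a hypothetical
// helper shown purely for illustration.
#[allow(dead_code)]
fn strict_options() -> RobustOptions {
    RobustOptions {
        weight_tol: 1e-6,
        irls_max_iter: 50,
        ..Default::default()
    }
}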
/// Minimize the robust cost `sum_i loss(r_i(x))` over the parameters `x`.
///
/// `residuals` maps the parameters and data to the residual vector. If
/// `jacobian` is `None`, a forward-difference approximation is used.
#[allow(dead_code)]
pub fn robust_least_squares<F, J, L, D, S1, S2>(
    residuals: F,
    x0: &ArrayBase<S1, Ix1>,
    loss: L,
    jacobian: Option<J>,
    data: &ArrayBase<S2, Ix1>,
    options: Option<RobustOptions>,
) -> OptimizeResult<OptimizeResults<f64>>
where
    F: Fn(&[f64], &[D]) -> Array1<f64>,
    J: Fn(&[f64], &[D]) -> Array2<f64>,
    L: RobustLoss,
    D: Clone,
    S1: Data<Elem = f64>,
    S2: Data<Elem = D>,
{
    let options = options.unwrap_or_default();
    if options.use_irls {
        irls_optimizer(residuals, x0, loss, jacobian, data, &options)
    } else {
        gradient_based_robust_optimizer(residuals, x0, loss, jacobian, data, &options)
    }
}
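// A minimal usage sketch. The closure-based residual below is illustrative
// only; `example_line_fit`, its data layout (t-values followed by y-values),
// and the chosen Cauchy scale are hypothetical, not part of this module's API.
#[allow(dead_code)]
fn example_line_fit() -> OptimizeResult<OptimizeResults<f64>> {
    use scirs2_core::ndarray::array;
    // Residuals of the line y = a + b * t against packed (t, y) data.
    let residual = |x: &[f64], data: &[f64]| -> Array1<f64> {
        let n = data.len() / 2;
        let mut res = Array1::zeros(n);
        for i in 0..n {
            res[i] = data[n + i] - (x[0] + x[1] * data[i]);
        }
        res
    };
    let x0 = array![0.0, 0.0];
    let data = array![0.0, 1.0, 2.0, 0.5, 1.4, 2.6];
    // No analytic Jacobian: forward differences are used internally.
    robust_least_squares(
        residual,
        &x0,
        CauchyLoss::new(1.0),
        None::<fn(&[f64], &[f64]) -> Array2<f64>>,
        &data,
        None,
    )
}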
/// Iteratively reweighted least squares: repeatedly solve the weighted
/// normal equations `(J^T W J) step = -(J^T W r)` with `W_ii = weight(r_i)`,
/// damping each step with a backtracking line search on the robust cost.
#[allow(dead_code)]
fn irls_optimizer<F, J, L, D, S1, S2>(
    residuals: F,
    x0: &ArrayBase<S1, Ix1>,
    loss: L,
    jacobian: Option<J>,
    data: &ArrayBase<S2, Ix1>,
    options: &RobustOptions,
) -> OptimizeResult<OptimizeResults<f64>>
where
    F: Fn(&[f64], &[D]) -> Array1<f64>,
    J: Fn(&[f64], &[D]) -> Array2<f64>,
    L: RobustLoss,
    D: Clone,
    S1: Data<Elem = f64>,
    S2: Data<Elem = D>,
{
    let mut x = x0.to_owned();
    let m = x.len();
    let max_nfev = options.max_nfev.unwrap_or(options.max_iter * m * 10);
    let mut nfev = 0;
    let mut njev = 0;
    let mut iter = 0;
    let mut res = residuals(
        x.as_slice().expect("parameter vector must be contiguous"),
        data.as_slice().expect("data array must be contiguous"),
    );
    nfev += 1;
    let n = res.len();
    let mut weights = Array1::ones(n);
    let mut prev_weights = weights.clone();
    // Set once any convergence criterion fires.
    let mut converged = false;
    // Forward-difference Jacobian for when no analytic Jacobian is supplied;
    // returns the Jacobian and the number of residual evaluations spent.
    let compute_numerical_jacobian =
        |x_val: &Array1<f64>, res_val: &Array1<f64>| -> (Array2<f64>, usize) {
            let eps = 1e-8;
            let mut jac = Array2::zeros((n, m));
            let mut count = 0;
            for j in 0..m {
                let mut x_h = x_val.clone();
                x_h[j] += eps;
                let res_h = residuals(
                    x_h.as_slice().expect("parameter vector must be contiguous"),
                    data.as_slice().expect("data array must be contiguous"),
                );
                count += 1;
                for i in 0..n {
                    jac[[i, j]] = (res_h[i] - res_val[i]) / eps;
                }
            }
            (jac, count)
        };
    while iter < options.irls_max_iter && nfev < max_nfev {
        // Recompute the IRLS weights from the current residuals.
        for i in 0..n {
            weights[i] = loss.weight(res[i]);
        }
        // Stop once the weights have stabilized (mean absolute change).
        let weight_change = weights
            .iter()
            .zip(prev_weights.iter())
            .map(|(&w, &pw)| (w - pw).abs())
            .sum::<f64>()
            / n as f64;
        if weight_change < options.weight_tol && iter > 0 {
            converged = true;
            break;
        }
        prev_weights = weights.clone();
        let jac = match &jacobian {
            Some(jac_fn) => {
                let j = jac_fn(
                    x.as_slice().expect("parameter vector must be contiguous"),
                    data.as_slice().expect("data array must be contiguous"),
                );
                njev += 1;
                j
            }
            None => {
                let (j, count) = compute_numerical_jacobian(&x, &res);
                nfev += count;
                j
            }
        };
        // Scale the rows of J and r by sqrt(w_i) so that plain inner products
        // yield J^T W J and J^T W r.
        let mut weighted_jac = Array2::zeros((n, m));
        let mut weighted_res = Array1::zeros(n);
        for i in 0..n {
            let w = weights[i].sqrt();
            for j in 0..m {
                weighted_jac[[i, j]] = jac[[i, j]] * w;
            }
            weighted_res[i] = res[i] * w;
        }
        let jt_wj = weighted_jac.t().dot(&weighted_jac);
        let neg_jt_wr = -weighted_jac.t().dot(&weighted_res);
        match solve(&jt_wj, &neg_jt_wr) {
            Some(step) => {
                // Backtracking line search: halve the step (up to 10 times)
                // until the robust cost strictly improves.
                let mut line_search_alpha = 1.0;
                let best_cost = compute_robust_cost(&res, &loss);
                let mut best_x = x.clone();
                for _ in 0..10 {
                    let x_new = &x + &step * line_search_alpha;
                    let res_new = residuals(
                        x_new.as_slice().expect("parameter vector must be contiguous"),
                        data.as_slice().expect("data array must be contiguous"),
                    );
                    nfev += 1;
                    let new_cost = compute_robust_cost(&res_new, &loss);
                    if new_cost < best_cost {
                        best_x = x_new;
                        break;
                    }
                    line_search_alpha *= 0.5;
                }
                let step_norm = step.iter().map(|&s| s * s).sum::<f64>().sqrt();
                let x_norm = x.iter().map(|&xi| xi * xi).sum::<f64>().sqrt();
                x = best_x;
                res = residuals(
                    x.as_slice().expect("parameter vector must be contiguous"),
                    data.as_slice().expect("data array must be contiguous"),
                );
                nfev += 1;
                // Converged once the (undamped) step is small relative to x.
                if step_norm < options.xtol * (1.0 + x_norm) {
                    converged = true;
                    break;
                }
            }
            None => {
                // Singular normal equations; no further progress possible.
                break;
            }
        }
        iter += 1;
    }
    let final_cost = compute_robust_cost(&res, &loss);
    let mut result = OptimizeResults::<f64>::default();
    result.x = x;
    result.fun = final_cost;
    result.nfev = nfev;
    result.njev = njev;
    result.nit = iter;
    // Success means a convergence test fired, not merely that the iteration
    // or evaluation budget ran out (or that the linear solve failed).
    result.success = converged;
    if result.success {
        result.message = "Optimization terminated successfully.".to_string();
    } else {
        result.message = "Optimization did not converge.".to_string();
    }
    Ok(result)
}
/// Placeholder for the gradient-based path (`use_irls = false`); it does
/// not optimize yet and always reports `success = false`.
#[allow(dead_code)]
fn gradient_based_robust_optimizer<F, J, L, D, S1, S2>(
    _residuals: F,
    x0: &ArrayBase<S1, Ix1>,
    _loss: L,
    _jacobian: Option<J>,
    _data: &ArrayBase<S2, Ix1>,
    _options: &RobustOptions,
) -> OptimizeResult<OptimizeResults<f64>>
where
    F: Fn(&[f64], &[D]) -> Array1<f64>,
    J: Fn(&[f64], &[D]) -> Array2<f64>,
    L: RobustLoss,
    D: Clone,
    S1: Data<Elem = f64>,
    S2: Data<Elem = D>,
{
    let mut result = OptimizeResults::<f64>::default();
    result.x = x0.to_owned();
    result.fun = 0.0;
    result.success = false;
    result.message = "Gradient-based robust optimization not yet implemented".to_string();
    Ok(result)
}
/// Total robust cost: the sum of `loss(r_i)` over all residuals.
#[allow(dead_code)]
fn compute_robust_cost<L: RobustLoss>(residuals: &Array1<f64>, loss: &L) -> f64 {
    residuals.iter().map(|&r| loss.loss(r)).sum()
}
/// Solve the square linear system `a * x = b`, returning `None` if the
/// solve fails (e.g. for a singular matrix).
#[allow(dead_code)]
fn solve(a: &Array2<f64>, b: &Array1<f64>) -> Option<Array1<f64>> {
    use scirs2_linalg::solve;
    solve(&a.view(), &b.view(), None).ok()
}
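// Hypothetical fallback, not wired into the module: `solve` above is the
// path actually used. If `scirs2_linalg` were unavailable, the small m x m
// normal-equation system could be solved with textbook Gaussian elimination
// and partial pivoting, as sketched here.
#[allow(dead_code)]
fn solve_dense_fallback(a: &Array2<f64>, b: &Array1<f64>) -> Option<Array1<f64>> {
    let n = b.len();
    // Build the augmented matrix [a | b].
    let mut aug = Array2::zeros((n, n + 1));
    for i in 0..n {
        for j in 0..n {
            aug[[i, j]] = a[[i, j]];
        }
        aug[[i, n]] = b[i];
    }
    for col in 0..n {
        // Partial pivoting: move the largest remaining entry onto the diagonal.
        let pivot = (col..n).max_by(|&i, &j| {
            aug[[i, col]]
                .abs()
                .partial_cmp(&aug[[j, col]].abs())
                .unwrap_or(std::cmp::Ordering::Equal)
        })?;
        if aug[[pivot, col]].abs() < 1e-12 {
            return None; // numerically singular system
        }
        if pivot != col {
            for j in 0..=n {
                let tmp = aug[[col, j]];
                aug[[col, j]] = aug[[pivot, j]];
                aug[[pivot, j]] = tmp;
            }
        }
        // Eliminate the entries below the pivot.
        for row in (col + 1)..n {
            let factor = aug[[row, col]] / aug[[col, col]];
            for j in col..=n {
                aug[[row, j]] -= factor * aug[[col, j]];
            }
        }
    }
    // Back substitution on the upper-triangular system.
    let mut x = Array1::zeros(n);
    for i in (0..n).rev() {
        let mut sum = aug[[i, n]];
        for j in (i + 1)..n {
            sum -= aug[[i, j]] * x[j];
        }
        x[i] = sum / aug[[i, i]];
    }
    Some(x)
}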
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;
    #[test]
    fn test_huber_loss() {
        let loss = HuberLoss::new(1.0);
        // Quadratic region (|r| <= delta): loss = r^2 / 2, weight = 1.
        assert!((loss.loss(0.5) - 0.125).abs() < 1e-10);
        assert!((loss.weight(0.5) - 1.0).abs() < 1e-10);
        // Linear region (|r| > delta): loss = delta * (|r| - delta / 2).
        assert!((loss.loss(2.0) - 1.5).abs() < 1e-10);
        assert!((loss.weight(2.0) - 0.5).abs() < 1e-10);
    }
    #[test]
    fn test_bisquare_loss() {
        let loss = BisquareLoss::new(4.685);
        let small_r = 1.0;
        assert!(loss.loss(small_r) > 0.0);
        assert!(loss.weight(small_r) > 0.0);
        assert!(loss.weight(small_r) < 1.0);
        // Beyond |r| = c the loss saturates and the weight drops to zero.
        let large_r = 5.0;
        assert!((loss.loss(large_r) - loss.loss(10.0)).abs() < 1e-10);
        assert_eq!(loss.weight(large_r), 0.0);
    }
    #[test]
    fn test_cauchy_loss() {
        let loss = CauchyLoss::new(1.0);
        // Weights decrease monotonically in |r|, and the loss is even.
        assert!(loss.weight(0.0) > loss.weight(1.0));
        assert!(loss.weight(1.0) > loss.weight(2.0));
        assert!(loss.weight(2.0) > loss.weight(5.0));
        assert_eq!(loss.loss(1.0), loss.loss(-1.0));
        assert_eq!(loss.weight(1.0), loss.weight(-1.0));
    }
    #[test]
    fn test_robust_least_squares_linear() {
        // data packs n t-values followed by n y-values; the model is
        // y = x[0] + x[1] * t.
        fn residual(x: &[f64], data: &[f64]) -> Array1<f64> {
            let n = data.len() / 2;
            let t_values = &data[0..n];
            let y_values = &data[n..];
            let mut res = Array1::zeros(n);
            for i in 0..n {
                res[i] = y_values[i] - (x[0] + x[1] * t_values[i]);
            }
            res
        }
        fn jacobian(_x: &[f64], data: &[f64]) -> Array2<f64> {
            let n = data.len() / 2;
            let t_values = &data[0..n];
            let mut jac = Array2::zeros((n, 2));
            for i in 0..n {
                jac[[i, 0]] = -1.0;
                jac[[i, 1]] = -t_values[i];
            }
            jac
        }
        let x0 = array![0.0, 0.0];
        // The last y-value (10.0, where the line predicts roughly 4) is a
        // gross outlier that the Huber loss should down-weight.
        let data_array = array![0.0, 1.0, 2.0, 3.0, 4.0, 0.1, 0.9, 2.1, 2.9, 10.0];
        let huber_loss = HuberLoss::new(1.0);
        let result =
            robust_least_squares(residual, &x0, huber_loss, Some(jacobian), &data_array, None)
                .expect("robust_least_squares failed");
        assert!(result.success);
        // The fitted slope stays near the true value 1 despite the outlier.
        assert!((result.x[1] - 1.0).abs() < 0.5);
    }
    #[test]
    fn test_irls_convergence() {
        fn residual(x: &[f64], _: &[f64]) -> Array1<f64> {
            array![x[0] - 1.0, x[1] - 2.0]
        }
        fn jacobian(_x: &[f64], _: &[f64]) -> Array2<f64> {
            array![[1.0, 0.0], [0.0, 1.0]]
        }
        let x0 = array![0.0, 0.0];
        let data = array![];
        let huber_loss = HuberLoss::new(1.0);
        let result = robust_least_squares(residual, &x0, huber_loss, Some(jacobian), &data, None)
            .expect("robust_least_squares failed");
        assert!(result.success);
        assert!((result.x[0] - 1.0).abs() < 1e-3);
        assert!((result.x[1] - 2.0).abs() < 1e-3);
    }
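    // Sanity-check sketch: with `SquaredLoss` every IRLS weight is 1, so a
    // single Gauss-Newton step solves a clean linear problem exactly and the
    // weights stabilize immediately afterwards.
    #[test]
    fn test_squared_loss_matches_ordinary_least_squares() {
        fn residual(x: &[f64], data: &[f64]) -> Array1<f64> {
            let n = data.len() / 2;
            let t_values = &data[0..n];
            let y_values = &data[n..];
            let mut res = Array1::zeros(n);
            for i in 0..n {
                res[i] = y_values[i] - (x[0] + x[1] * t_values[i]);
            }
            res
        }
        fn jacobian(_x: &[f64], data: &[f64]) -> Array2<f64> {
            let n = data.len() / 2;
            let t_values = &data[0..n];
            let mut jac = Array2::zeros((n, 2));
            for i in 0..n {
                jac[[i, 0]] = -1.0;
                jac[[i, 1]] = -t_values[i];
            }
            jac
        }
        let x0 = array![0.0, 0.0];
        // t = 0..=3 followed by y = 1 + 2 t, with no outliers.
        let data = array![0.0, 1.0, 2.0, 3.0, 1.0, 3.0, 5.0, 7.0];
        let result =
            robust_least_squares(residual, &x0, SquaredLoss, Some(jacobian), &data, None)
                .expect("robust_least_squares failed");
        assert!(result.success);
        assert!((result.x[0] - 1.0).abs() < 1e-6);
        assert!((result.x[1] - 2.0).abs() < 1e-6);
    }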
}