#![allow(clippy::needless_range_loop)]
use crate::{error::Error, linalg::Matrix};
use super::{
types::{LoessFit, LoessOptions},
normalize::normalize_predictors,
predict::fit_at_point_impl,
robust::compute_biweight_weights,
weights::MIN_NEIGHBORS_QUADRATIC,
};
/// Fits a LOESS (locally weighted regression) model of `y` on the predictors in `x`.
///
/// Each element of `x` is one predictor variable and must contain exactly
/// `y.len()` values. The predictors are assembled into a row-major `n x p`
/// matrix, normalized with [`normalize_predictors`], and every observation is
/// then fitted locally with [`fit_at_point_impl`].
///
/// The first pass is an ordinary fit (no robustness weights). When
/// `options.robust_iterations > 0`, that many additional passes are run. Each
/// one uses weights produced by [`compute_biweight_weights`] from the previous
/// pass's residuals. If the residuals are negligible, the weights are left
/// unchanged (see the in-body comment).
///
/// # Errors
///
/// * [`Error::InsufficientData`] if `y` has fewer than 2 observations, or fewer
///   than `MIN_NEIGHBORS_QUADRATIC` when `options.degree == 2`.
/// * [`Error::InvalidInput`] in any of these cases:
///   * `x` is empty;
///   * a predictor's length differs from `y.len()`;
///   * `options.span` is not in `(0, 1]` (including NaN);
///   * `options.degree > 2`.
/// * Any error propagated from [`fit_at_point_impl`].
pub fn loess_fit(
    y: &[f64],
    x: &[Vec<f64>],
    options: &LoessOptions,
) -> Result<LoessFit, Error> {
    let n = y.len();
    let p = x.len();

    let min_required = if options.degree == 2 {
        MIN_NEIGHBORS_QUADRATIC
    } else {
        2
    };
    if n < min_required {
        return Err(Error::InsufficientData {
            required: min_required,
            available: n,
        });
    }
    if p == 0 {
        return Err(Error::InvalidInput(
            "At least one predictor variable is required".to_string(),
        ));
    }
    for (i, x_var) in x.iter().enumerate() {
        if x_var.len() != n {
            return Err(Error::InvalidInput(format!(
                "x[{}] has {} elements, expected {}",
                i,
                x_var.len(),
                n
            )));
        }
    }
    // NaN compares false against both bounds, so it has to be rejected
    // explicitly or it would pass the range check below.
    if options.span.is_nan() || options.span <= 0.0 || options.span > 1.0 {
        return Err(Error::InvalidInput(format!(
            "Span must be in (0, 1], got {}",
            options.span
        )));
    }
    if options.degree > 2 {
        return Err(Error::InvalidInput(
            "Degree must be 0 (constant), 1 (linear), or 2 (quadratic)".to_string(),
        ));
    }

    // `x` is stored one predictor per Vec; lay it out row by row (observation i
    // occupies entries i*p .. i*p + p) for the n x p matrix.
    let mut x_data = Vec::with_capacity(n * p);
    for i in 0..n {
        x_data.extend(x.iter().map(|column| column[i]));
    }
    let x_matrix = Matrix::new(n, p, x_data);
    let (x_normalized, _normalization_info) = normalize_predictors(&x_matrix);

    let mut robust_weights = vec![1.0; n];
    let mut fitted = vec![0.0; n];
    // Scratch buffer for the current query point, reused to avoid one
    // allocation per local fit.
    let mut query = Vec::with_capacity(p);

    for iteration in 0..=options.robust_iterations {
        // Pass 0 is the plain fit; later passes use the robustness weights.
        let robustness_weights = (iteration > 0).then_some(robust_weights.as_slice());
        for (i, fitted_value) in fitted.iter_mut().enumerate() {
            query.clear();
            query.extend((0..p).map(|j| x_normalized.get(i, j)));
            *fitted_value =
                fit_at_point_impl(&query, &x_normalized, y, options, robustness_weights)?;
        }

        if iteration < options.robust_iterations {
            let residuals: Vec<f64> = y
                .iter()
                .zip(&fitted)
                .map(|(observed, estimate)| observed - estimate)
                .collect();
            let mut abs_residuals: Vec<f64> = residuals.iter().map(|r| r.abs()).collect();
            // Summed before sorting so the summation order matches the residual order.
            let mean_abs_residual = abs_residuals.iter().sum::<f64>() / n as f64;

            // `total_cmp` is a total order, so NaN residuals cannot make the sort
            // panic or misorder; after `abs` they sort to the end.
            abs_residuals.sort_unstable_by(f64::total_cmp);
            let mid = n / 2; // n >= 2 was validated above, so `mid - 1` is in bounds.
            let median_abs_residual = if n.is_multiple_of(2) {
                (abs_residuals[mid - 1] + abs_residuals[mid]) / 2.0
            } else {
                abs_residuals[mid]
            };

            // When six times the median absolute residual is negligible relative
            // to the mean absolute residual, the fit is treated as essentially
            // exact. In that case the current weights are kept rather than
            // recomputed from near-zero residuals.
            let residual_scale = 6.0 * median_abs_residual;
            if residual_scale >= 1e-7 * mean_abs_residual {
                robust_weights = compute_biweight_weights(&residuals);
            }
        }
    }

    Ok(LoessFit {
        fitted,
        predictions: None,
        span: options.span,
        degree: options.degree,
        robust_iterations: options.robust_iterations,
        surface: options.surface,
    })
}