#[cfg(not(feature = "std"))]
use alloc::vec::Vec;
use core::fmt::Debug;
use num_traits::Float;
#[cfg(feature = "std")]
use std::vec::Vec;
use crate::algorithms::interpolation::calculate_delta;
use crate::algorithms::regression::{WLSSolver, ZeroWeightFallback};
use crate::algorithms::robustness::RobustnessMethod;
use crate::engine::executor::{CVPassFn, FitPassFn, IntervalPassFn, SmoothPassFn};
use crate::engine::executor::{LowessConfig, LowessExecutor};
use crate::engine::output::LowessResult;
use crate::engine::validator::Validator;
use crate::evaluation::cv::CVKind;
use crate::evaluation::diagnostics::Diagnostics;
use crate::evaluation::intervals::IntervalMethod;
use crate::math::boundary::BoundaryPolicy;
use crate::math::kernel::WeightFunction;
use crate::math::scaling::ScalingMethod;
use crate::primitives::backend::Backend;
use crate::primitives::errors::LowessError;
use crate::primitives::sorting::{SortedData, sort_by_x, unsort};
/// Configuration builder for batch LOWESS smoothing.
///
/// Options are set with the chained setter methods (or, for crate-internal callers, through the
/// public fields directly). `build` validates them and wraps the configuration in a
/// `BatchLowess`, whose `fit` runs the smoother.
#[derive(Debug, Clone)]
pub struct BatchLowessBuilder<T: Float> {
/// Smoothing fraction; defaults to 0.67 and is checked by `Validator::validate_fraction`.
pub fraction: T,
/// Iteration count forwarded to the executor (default 3). This is presumably the number of
/// robustness iterations — confirm against `LowessExecutor`.
pub iterations: usize,
/// Explicit delta. When `None`, `fit` resolves one with `calculate_delta` from the sorted x values.
pub delta: Option<T>,
/// Kernel weight function forwarded to the executor.
pub weight_function: WeightFunction,
/// Robustness method forwarded to the executor.
pub robustness_method: RobustnessMethod,
/// Requested interval kind, set by `confidence_intervals` / `prediction_intervals`. `fit`
/// forwards it to the executor as `return_variance` and uses it to produce interval bounds.
pub interval_type: Option<IntervalMethod<T>>,
/// Candidate fractions for cross-validation (set by `cross_validate`).
pub cv_fractions: Option<Vec<T>>,
/// Cross-validation scheme. For `CVKind::KFold(k)`, `build` checks `k`.
pub cv_kind: Option<CVKind>,
/// Seed forwarded to the executor as `cv_seed`; no setter for it exists in this file.
pub cv_seed: Option<u64>,
/// Error recorded before `build`. If present, `build` returns it before running any validation.
pub deferred_error: Option<LowessError>,
/// Convergence tolerance (set by `auto_converge`), checked by `Validator::validate_tolerance`.
pub auto_convergence: Option<T>,
/// When true, `fit` computes `Diagnostics` and attaches them to the result.
pub return_diagnostics: bool,
/// When true, `fit` includes residuals, reordered with `unsort`, in the result.
pub compute_residuals: bool,
/// When true, `fit` includes the executor's robustness weights, reordered with `unsort`.
pub return_robustness_weights: bool,
/// Zero-weight fallback; `fit` passes it to the executor as its `to_u8()` code.
pub zero_weight_fallback: ZeroWeightFallback,
/// Boundary policy forwarded to the executor.
pub boundary_policy: BoundaryPolicy,
/// Scaling method forwarded to the executor.
pub scaling_method: ScalingMethod,
/// Optional smoothing-pass override, forwarded unchanged to `LowessConfig`.
#[doc(hidden)]
pub custom_smooth_pass: Option<SmoothPassFn<T>>,
/// Optional cross-validation-pass override, forwarded unchanged to `LowessConfig`.
#[doc(hidden)]
pub custom_cv_pass: Option<CVPassFn<T>>,
/// Optional interval-pass override, forwarded unchanged to `LowessConfig`.
#[doc(hidden)]
pub custom_interval_pass: Option<IntervalPassFn<T>>,
/// Optional fit-pass override, forwarded unchanged to `LowessConfig` (no setter in this file).
#[doc(hidden)]
pub custom_fit_pass: Option<FitPassFn<T>>,
/// Execution backend. With `Some(Backend::GPU)`, `fit` skips host-side sorting and uses an
/// identity permutation.
#[doc(hidden)]
pub backend: Option<Backend>,
/// Parallel-execution toggle; `fit` treats `None` as `false`.
#[doc(hidden)]
pub parallel: Option<bool>,
/// Forwarded to the executor as `delegate_boundary_handling`.
#[doc(hidden)]
pub delegate_boundary_handling: bool,
/// Checked by `Validator::validate_no_duplicates` in `build`. Presumably names a parameter that
/// was supplied more than once — confirm at the call sites that populate it.
#[doc(hidden)]
pub(crate) duplicate_param: Option<&'static str>,
}
impl<T: Float> Default for BatchLowessBuilder<T> {
fn default() -> Self {
Self::new()
}
}
impl<T: Float> BatchLowessBuilder<T> {
    /// Creates a builder holding the default configuration.
    ///
    /// Defaults:
    /// - `fraction` 0.67 and 3 iterations.
    /// - No explicit delta, no intervals, no cross-validation, and no convergence tolerance.
    /// - All optional outputs disabled.
    /// - The `Default` weight function, robustness method, zero-weight fallback, boundary policy
    ///   and scaling method.
    /// - No custom passes, no backend override, and `parallel` unset.
    fn new() -> Self {
        Self {
            // `NumCast::from` only returns `None` when the target type cannot represent the
            // value; 0.67 is representable by the standard float types (f32, f64).
            fraction: T::from(0.67).expect("0.67 must be representable by the Float type"),
            iterations: 3,
            delta: None,
            weight_function: WeightFunction::default(),
            robustness_method: RobustnessMethod::default(),
            interval_type: None,
            cv_fractions: None,
            cv_kind: None,
            cv_seed: None,
            deferred_error: None,
            auto_convergence: None,
            return_diagnostics: false,
            compute_residuals: false,
            return_robustness_weights: false,
            zero_weight_fallback: ZeroWeightFallback::default(),
            boundary_policy: BoundaryPolicy::default(),
            scaling_method: ScalingMethod::default(),
            custom_smooth_pass: None,
            custom_cv_pass: None,
            custom_interval_pass: None,
            custom_fit_pass: None,
            backend: None,
            delegate_boundary_handling: false,
            parallel: None,
            duplicate_param: None,
        }
    }

    /// Sets the smoothing fraction. It is checked by `Validator::validate_fraction` in `build`.
    #[must_use]
    pub fn fraction(mut self, fraction: T) -> Self {
        self.fraction = fraction;
        self
    }

    /// Sets the iteration count forwarded to the executor (default 3). It is checked by
    /// `Validator::validate_iterations` in `build`.
    #[must_use]
    pub fn iterations(mut self, iterations: usize) -> Self {
        self.iterations = iterations;
        self
    }

    /// Overrides delta. It is checked by `Validator::validate_delta` in `build`. When this is
    /// never called, `fit` derives delta with `calculate_delta` from the sorted x values.
    #[must_use]
    pub fn delta(mut self, delta: T) -> Self {
        self.delta = Some(delta);
        self
    }

    /// Selects the kernel weight function.
    #[must_use]
    pub fn weight_function(mut self, wf: WeightFunction) -> Self {
        self.weight_function = wf;
        self
    }

    /// Selects the robustness method.
    #[must_use]
    pub fn robustness_method(mut self, method: RobustnessMethod) -> Self {
        self.robustness_method = method;
        self
    }

    /// Selects the zero-weight fallback. `fit` passes it to the executor as a `u8` code.
    #[must_use]
    pub fn zero_weight_fallback(mut self, fallback: ZeroWeightFallback) -> Self {
        self.zero_weight_fallback = fallback;
        self
    }

    /// Selects the boundary policy.
    #[must_use]
    pub fn boundary_policy(mut self, policy: BoundaryPolicy) -> Self {
        self.boundary_policy = policy;
        self
    }

    /// Sets the convergence tolerance stored in `auto_convergence`. It is checked by
    /// `Validator::validate_tolerance` in `build`. Presumably this lets the executor stop
    /// iterating early — confirm in `LowessExecutor`.
    #[must_use]
    pub fn auto_converge(mut self, tolerance: T) -> Self {
        self.auto_convergence = Some(tolerance);
        self
    }

    /// When enabled, `fit` returns residuals in the result.
    #[must_use]
    pub fn compute_residuals(mut self, enabled: bool) -> Self {
        self.compute_residuals = enabled;
        self
    }

    /// When enabled, `fit` returns the executor's robustness weights in the result.
    #[must_use]
    pub fn return_robustness_weights(mut self, enabled: bool) -> Self {
        self.return_robustness_weights = enabled;
        self
    }

    /// When enabled, `fit` computes `Diagnostics` for the fit.
    #[must_use]
    pub fn return_diagnostics(mut self, enabled: bool) -> Self {
        self.return_diagnostics = enabled;
        self
    }

    /// Requests confidence intervals at `level`. The level is checked by
    /// `Validator::validate_interval_level` in `build`.
    ///
    /// This shares the `interval_type` slot with `prediction_intervals`, so whichever of the
    /// two is called last wins.
    #[must_use]
    pub fn confidence_intervals(mut self, level: T) -> Self {
        self.interval_type = Some(IntervalMethod::confidence(level));
        self
    }

    /// Requests prediction intervals at `level`. The level is checked by
    /// `Validator::validate_interval_level` in `build`.
    ///
    /// This shares the `interval_type` slot with `confidence_intervals`, so whichever of the
    /// two is called last wins.
    #[must_use]
    pub fn prediction_intervals(mut self, level: T) -> Self {
        self.interval_type = Some(IntervalMethod::prediction(level));
        self
    }

    /// Supplies candidate fractions for cross-validation. They are checked by
    /// `Validator::validate_cv_fractions` in `build`.
    #[must_use]
    pub fn cross_validate(mut self, fractions: Vec<T>) -> Self {
        self.cv_fractions = Some(fractions);
        self
    }

    /// Selects the cross-validation scheme. For `CVKind::KFold(k)`, `k` is checked by
    /// `Validator::validate_kfold` in `build`.
    #[must_use]
    pub fn cv_kind(mut self, kind: CVKind) -> Self {
        self.cv_kind = Some(kind);
        self
    }

    /// Overrides the execution backend. With `Backend::GPU`, `fit` skips host-side sorting.
    #[doc(hidden)]
    #[must_use]
    pub fn backend(mut self, backend: Backend) -> Self {
        self.backend = Some(backend);
        self
    }

    /// Sets the parallel-execution toggle; `fit` treats an unset value as `false`.
    #[doc(hidden)]
    #[must_use]
    pub fn parallel(mut self, parallel: bool) -> Self {
        self.parallel = Some(parallel);
        self
    }

    /// Installs a custom smoothing pass, forwarded unchanged to the executor config.
    #[doc(hidden)]
    #[must_use]
    pub fn custom_smooth_pass(mut self, pass: SmoothPassFn<T>) -> Self {
        self.custom_smooth_pass = Some(pass);
        self
    }

    /// Installs a custom cross-validation pass, forwarded unchanged to the executor config.
    #[doc(hidden)]
    #[must_use]
    pub fn custom_cv_pass(mut self, pass: CVPassFn<T>) -> Self {
        self.custom_cv_pass = Some(pass);
        self
    }

    /// Installs a custom interval pass, forwarded unchanged to the executor config.
    #[doc(hidden)]
    #[must_use]
    pub fn custom_interval_pass(mut self, pass: IntervalPassFn<T>) -> Self {
        self.custom_interval_pass = Some(pass);
        self
    }

    /// Validates the configuration and produces a [`BatchLowess`] model.
    ///
    /// # Errors
    ///
    /// Returns the recorded `deferred_error` if there is one. Otherwise it returns the first
    /// failing `Validator` check, run in this order:
    /// 1. Duplicate parameters.
    /// 2. Fraction.
    /// 3. Iterations.
    /// 4. Delta, if set.
    /// 5. Interval level, if intervals were requested.
    /// 6. CV fractions, if set.
    /// 7. k, if `cv_kind` is `CVKind::KFold(k)`.
    /// 8. Convergence tolerance, if set.
    pub fn build(self) -> Result<BatchLowess<T>, LowessError> {
        if let Some(err) = self.deferred_error {
            return Err(err);
        }
        Validator::validate_no_duplicates(self.duplicate_param)?;
        Validator::validate_fraction(self.fraction)?;
        Validator::validate_iterations(self.iterations)?;
        if let Some(delta) = self.delta {
            Validator::validate_delta(delta)?;
        }
        if let Some(ref method) = self.interval_type {
            Validator::validate_interval_level(method.level)?;
        }
        if let Some(ref fracs) = self.cv_fractions {
            Validator::validate_cv_fractions(fracs)?;
        }
        if let Some(CVKind::KFold(k)) = self.cv_kind {
            Validator::validate_kfold(k)?;
        }
        if let Some(tol) = self.auto_convergence {
            Validator::validate_tolerance(tol)?;
        }
        Ok(BatchLowess { config: self })
    }
}
/// A validated LOWESS model produced by [`BatchLowessBuilder::build`].
///
/// It stores the builder's configuration unchanged; call `fit` to smooth data with it.
/// `Debug` and `Clone` are derived to match the builder, which derives both.
#[derive(Debug, Clone)]
pub struct BatchLowess<T: Float> {
    /// The configuration that passed `build`'s validation.
    config: BatchLowessBuilder<T>,
}
impl<T: Float + WLSSolver + Debug + Send + Sync + 'static> BatchLowess<T> {
    /// Smooths `y` against `x` using the configuration captured by the builder.
    ///
    /// Unless the backend is `Backend::GPU`, the points are sorted by x before they are handed
    /// to `LowessExecutor::run_with_config`. Every per-point output is passed back through
    /// `unsort` before it is returned. The returned `x` is a copy of the caller's input.
    ///
    /// Residuals come from the executor when it supplies them. Otherwise they are computed as
    /// observed minus smoothed. Interval bounds come from the executor when it produced any.
    /// Otherwise, if an interval kind was requested and standard errors exist, they are built
    /// with `compute_intervals`.
    ///
    /// # Errors
    ///
    /// Propagates failures from:
    /// - `Validator::validate_inputs`
    /// - `calculate_delta`
    /// - `LowessExecutor::run_with_config`
    /// - `compute_intervals`, when intervals are computed here
    pub fn fit(self, x: &[T], y: &[T]) -> Result<LowessResult<T>, LowessError> {
        Validator::validate_inputs(x, y)?;

        // The GPU path keeps caller order and pairs it with an identity permutation; every other
        // backend works on an x-sorted copy.
        let use_gpu = self.config.backend == Some(Backend::GPU);
        let sorted = if !use_gpu {
            sort_by_x(x, y)
        } else {
            SortedData {
                x: x.to_vec(),
                y: y.to_vec(),
                indices: (0..x.len()).collect(),
            }
        };

        let delta = calculate_delta(self.config.delta, &sorted.x)?;
        let exec_config = LowessConfig {
            fraction: Some(self.config.fraction),
            iterations: self.config.iterations,
            delta,
            weight_function: self.config.weight_function,
            // The executor config carries the fallback as the `u8` produced by `to_u8`.
            zero_weight_fallback: self.config.zero_weight_fallback.to_u8(),
            robustness_method: self.config.robustness_method,
            cv_fractions: self.config.cv_fractions,
            cv_kind: self.config.cv_kind,
            auto_convergence: self.config.auto_convergence,
            return_variance: self.config.interval_type,
            boundary_policy: self.config.boundary_policy,
            scaling_method: self.config.scaling_method,
            cv_seed: self.config.cv_seed,
            custom_smooth_pass: self.config.custom_smooth_pass,
            custom_cv_pass: self.config.custom_cv_pass,
            custom_interval_pass: self.config.custom_interval_pass,
            custom_fit_pass: self.config.custom_fit_pass,
            parallel: self.config.parallel.unwrap_or(false),
            backend: self.config.backend,
            delegate_boundary_handling: self.config.delegate_boundary_handling,
        };

        let output = LowessExecutor::run_with_config(&sorted.x, &sorted.y, exec_config)?;
        let fitted = output.smoothed;
        let std_err = output.std_errors;
        let iterations_used = output.iterations;
        let fraction_used = output.used_fraction;
        let cv_scores = output.cv_scores;

        // Residuals in sorted order: taken from the executor when present, else observed - smoothed.
        let resid = match output.residuals {
            Some(from_executor) => from_executor,
            None => sorted
                .y
                .iter()
                .zip(fitted.iter())
                .map(|(&observed, &smooth)| observed - smooth)
                .collect::<Vec<T>>(),
        };

        let diagnostics = match self.config.return_diagnostics {
            false => None,
            true => Some(Diagnostics::compute(
                &sorted.y,
                &fitted,
                &resid,
                std_err.as_deref(),
            )),
        };

        let (conf_lower, conf_upper, pred_lower, pred_upper) =
            match (&self.config.interval_type, &std_err) {
                // Bounds already produced by the executor take precedence over recomputing them.
                (Some(_), Some(_))
                    if output.confidence_lower.is_some() || output.prediction_lower.is_some() =>
                {
                    (
                        output.confidence_lower,
                        output.confidence_upper,
                        output.prediction_lower,
                        output.prediction_upper,
                    )
                }
                (Some(method), Some(errors)) => {
                    method.compute_intervals(&fitted, errors, &resid)?
                }
                _ => (None, None, None, None),
            };

        let order = &sorted.indices;
        // Fields are listed so the `unsort` calls run in the same sequence as before:
        // smoothed, standard errors, residuals, weights, then the four interval bounds.
        Ok(LowessResult {
            y: unsort(&fitted, order),
            standard_errors: std_err.as_ref().map(|v| unsort(v, order)),
            residuals: if self.config.compute_residuals {
                Some(unsort(&resid, order))
            } else {
                None
            },
            robustness_weights: if self.config.return_robustness_weights {
                Some(unsort(&output.robustness_weights, order))
            } else {
                None
            },
            confidence_lower: conf_lower.as_ref().map(|v| unsort(v, order)),
            confidence_upper: conf_upper.as_ref().map(|v| unsort(v, order)),
            prediction_lower: pred_lower.as_ref().map(|v| unsort(v, order)),
            prediction_upper: pred_upper.as_ref().map(|v| unsort(v, order)),
            x: x.to_vec(),
            fraction_used,
            iterations_used,
            cv_scores,
            diagnostics,
        })
    }
}