#[cfg(feature = "cpu")]
use crate::engine::executor::{smooth_pass_parallel, vertex_pass_parallel};
#[cfg(feature = "cpu")]
use crate::evaluation::cv::cv_pass_parallel;
#[cfg(feature = "cpu")]
use crate::evaluation::intervals::interval_pass_parallel;
use num_traits::Float;
use std::fmt::Debug;
use std::result::Result;
use loess_rs::internals::adapters::batch::BatchLoessBuilder;
use loess_rs::internals::algorithms::regression::PolynomialDegree;
use loess_rs::internals::algorithms::regression::SolverLinalg;
use loess_rs::internals::algorithms::regression::ZeroWeightFallback;
use loess_rs::internals::algorithms::robustness::RobustnessMethod;
use loess_rs::internals::api::SurfaceMode;
use loess_rs::internals::engine::output::LoessResult;
use loess_rs::internals::evaluation::cv::{CVConfig, CVKind};
use loess_rs::internals::math::boundary::BoundaryPolicy;
use loess_rs::internals::math::distance::DistanceLinalg;
use loess_rs::internals::math::distance::DistanceMetric;
use loess_rs::internals::math::kernel::WeightFunction;
use loess_rs::internals::math::linalg::FloatLinalg;
use loess_rs::internals::math::scaling::ScalingMethod;
use loess_rs::internals::primitives::backend::Backend;
use loess_rs::internals::primitives::errors::LoessError;
use crate::input::LoessInput;
use crate::math::neighborhood::build_kdtree_parallel;
/// Builder for a [`ParallelBatchLoess`] fitter.
///
/// Every piece of configuration is stored in the wrapped [`BatchLoessBuilder`]; the setters
/// defined on this type simply write through to it.
#[derive(Debug, Clone)]
pub struct ParallelBatchLoessBuilder<T>
where
    T: FloatLinalg + DistanceLinalg + SolverLinalg,
{
    /// The underlying batch configuration that all setters modify.
    pub base: BatchLoessBuilder<T>,
}
impl<T: FloatLinalg + DistanceLinalg + SolverLinalg + Debug + Send + Sync> Default
for ParallelBatchLoessBuilder<T>
{
fn default() -> Self {
Self::new()
}
}
impl<T: FloatLinalg + DistanceLinalg + SolverLinalg + Debug + Send + Sync>
ParallelBatchLoessBuilder<T>
{
fn new() -> Self {
let mut base = BatchLoessBuilder::default();
base.parallel = Some(true); Self { base }
}
pub fn parallel(mut self, parallel: bool) -> Self {
self.base.parallel = Some(parallel);
self
}
pub fn backend(mut self, backend: Backend) -> Self {
self.base.backend = Some(backend);
self
}
pub fn fraction(mut self, fraction: T) -> Self {
self.base.fraction = fraction;
self
}
pub fn iterations(mut self, iterations: usize) -> Self {
self.base.iterations = iterations;
self
}
pub fn weight_function(mut self, wf: WeightFunction) -> Self {
self.base.weight_function = wf;
self
}
pub fn robustness_method(mut self, method: RobustnessMethod) -> Self {
self.base.robustness_method = method;
self
}
pub fn scaling_method(mut self, method: ScalingMethod) -> Self {
self.base.scaling_method = method;
self
}
pub fn zero_weight_fallback(mut self, fallback: ZeroWeightFallback) -> Self {
self.base.zero_weight_fallback = fallback;
self
}
pub fn boundary_policy(mut self, policy: BoundaryPolicy) -> Self {
self.base.boundary_policy = policy;
self
}
pub fn polynomial_degree(mut self, degree: PolynomialDegree) -> Self {
self.base.polynomial_degree = degree;
self
}
pub fn dimensions(mut self, dims: usize) -> Self {
self.base.dimensions = dims;
self
}
pub fn distance_metric(mut self, metric: DistanceMetric<T>) -> Self {
self.base.distance_metric = metric;
self
}
pub fn surface_mode(mut self, mode: SurfaceMode) -> Self {
self.base.surface_mode = mode;
self
}
pub fn cell(mut self, cell: f64) -> Self {
self.base.cell = Some(cell);
self
}
pub fn boundary_degree_fallback(mut self, enabled: bool) -> Self {
self.base = self.base.boundary_degree_fallback(enabled);
self
}
pub fn interpolation_vertices(mut self, vertices: usize) -> Self {
self.base.interpolation_vertices = Some(vertices);
self
}
pub fn auto_converge(mut self, tolerance: T) -> Self {
self.base.auto_converge = Some(tolerance);
self
}
pub fn compute_residuals(mut self, enabled: bool) -> Self {
self.base.compute_residuals = enabled;
self
}
pub fn return_robustness_weights(mut self, enabled: bool) -> Self {
self.base.return_robustness_weights = enabled;
self
}
pub fn return_diagnostics(mut self, enabled: bool) -> Self {
self.base.return_diagnostics = enabled;
self
}
pub fn confidence_intervals(mut self, level: T) -> Self {
self.base = self.base.confidence_intervals(level);
self
}
pub fn prediction_intervals(mut self, level: T) -> Self {
self.base = self.base.prediction_intervals(level);
self
}
pub fn cross_validate(mut self, config: CVConfig<'_, T>) -> Self {
self.base.cv_fractions = Some(config.fractions().to_vec());
self.base.cv_kind = Some(config.kind());
self.base.cv_seed = config.get_seed();
self
}
pub fn cv_seed(mut self, seed: u64) -> Self {
self.base.cv_seed = Some(seed);
self
}
pub fn cv_kind(mut self, method: CVKind) -> Self {
self.base.cv_kind = Some(method);
self
}
pub fn return_se(mut self, enabled: bool) -> Self {
self.base = self.base.return_se(enabled);
self
}
pub fn build(self) -> Result<ParallelBatchLoess<T>, LoessError> {
if let Some(ref err) = self.base.deferred_error {
return Err(err.clone());
}
let _ = self.base.clone().build()?;
Ok(ParallelBatchLoess { config: self })
}
}
/// A validated LOESS fitter produced by [`ParallelBatchLoessBuilder::build`].
///
/// [`fit`](ParallelBatchLoess::fit) consumes `self`, so this type is `Clone`: clone the
/// fitter first to apply the same configuration to several datasets. It also derives
/// `Debug`, so it can be inspected or printed (e.g. via `Result::unwrap_err`).
#[derive(Debug, Clone)]
pub struct ParallelBatchLoess<T: FloatLinalg + DistanceLinalg + SolverLinalg> {
    /// The builder this fitter was created from; `fit` consumes its `base` configuration.
    config: ParallelBatchLoessBuilder<T>,
}
impl<T: FloatLinalg + DistanceLinalg + SolverLinalg + Float + Debug + Send + Sync + 'static>
ParallelBatchLoess<T>
{
pub fn fit<I1, I2>(self, x: &I1, y: &I2) -> Result<LoessResult<T>, LoessError>
where
I1: LoessInput<T> + ?Sized,
I2: LoessInput<T> + ?Sized,
{
let x_slice = x.as_loess_slice()?;
let y_slice = y.as_loess_slice()?;
let mut builder = self.config.base;
match builder.backend.unwrap_or(Backend::CPU) {
Backend::CPU => {
#[cfg(feature = "cpu")]
{
if builder.parallel.unwrap_or(true) {
builder.custom_smooth_pass = Some(smooth_pass_parallel);
builder.custom_cv_pass = Some(cv_pass_parallel);
builder.custom_interval_pass = Some(interval_pass_parallel);
builder.custom_vertex_pass = Some(vertex_pass_parallel);
builder.custom_kdtree_builder = Some(build_kdtree_parallel);
}
}
#[cfg(not(feature = "cpu"))]
{
builder.custom_smooth_pass = None;
builder.custom_cv_pass = None;
builder.custom_interval_pass = None;
builder.custom_vertex_pass = None;
}
}
Backend::GPU => {
return Err(LoessError::UnsupportedFeature {
adapter: "Batch",
feature: "GPU backend (not yet implemented)",
});
}
}
let processor = builder.build()?;
processor.fit(x_slice, y_slice)
}
}