#[cfg(feature = "cpu")]
use crate::engine::executor::{smooth_pass_parallel, vertex_pass_parallel};
#[cfg(feature = "cpu")]
use crate::evaluation::cv::cv_pass_parallel;
#[cfg(feature = "cpu")]
use crate::evaluation::intervals::interval_pass_parallel;
use num_traits::Float;
use std::fmt::Debug;
use std::result::Result;
use crate::math::neighborhood::build_kdtree_parallel;
use loess_rs::internals::adapters::online::{OnlineLoessBuilder, OnlineOutput, UpdateMode};
use loess_rs::internals::algorithms::regression::PolynomialDegree;
use loess_rs::internals::algorithms::regression::SolverLinalg;
use loess_rs::internals::algorithms::regression::ZeroWeightFallback;
use loess_rs::internals::algorithms::robustness::RobustnessMethod;
use loess_rs::internals::engine::executor::SurfaceMode;
use loess_rs::internals::math::boundary::BoundaryPolicy;
use loess_rs::internals::math::distance::DistanceLinalg;
use loess_rs::internals::math::distance::DistanceMetric;
use loess_rs::internals::math::kernel::WeightFunction;
use loess_rs::internals::math::linalg::FloatLinalg;
use loess_rs::internals::math::scaling::ScalingMethod;
use loess_rs::internals::primitives::backend::Backend;
use loess_rs::internals::primitives::errors::LoessError;
/// Builder for [`ParallelOnlineLoess`], implemented as a thin wrapper around
/// the base `OnlineLoessBuilder`.
///
/// Every configuration value is stored on the wrapped builder; this type only
/// adds the crate's parallel pass hooks when `build` is called.
#[derive(Debug, Clone)]
pub struct ParallelOnlineLoessBuilder<T>
where
    T: FloatLinalg + DistanceLinalg + SolverLinalg,
{
    /// Wrapped base builder that holds all configuration.
    pub base: OnlineLoessBuilder<T>,
}
impl<T: FloatLinalg + DistanceLinalg + SolverLinalg + Debug + Send + Sync> Default
for ParallelOnlineLoessBuilder<T>
{
fn default() -> Self {
Self::new()
}
}
impl<T: FloatLinalg + DistanceLinalg + SolverLinalg + Debug + Send + Sync>
    ParallelOnlineLoessBuilder<T>
{
    /// Creates a builder from `OnlineLoessBuilder::default()` with the
    /// `parallel` flag explicitly set to `Some(false)`.
    fn new() -> Self {
        let mut base = OnlineLoessBuilder::default();
        base.parallel = Some(false);
        Self { base }
    }

    /// Chooses whether `build` installs the parallel pass hooks.
    pub fn parallel(mut self, parallel: bool) -> Self {
        self.base.parallel = Some(parallel);
        self
    }

    /// Records the requested backend on the base builder.
    ///
    /// `build` on this type does not inspect the value itself; any effect
    /// comes from the base builder.
    pub fn backend(mut self, backend: Backend) -> Self {
        self.base.backend = Some(backend);
        self
    }

    /// Sets the `fraction` option of the base builder.
    pub fn fraction(mut self, fraction: T) -> Self {
        self.base.fraction = fraction;
        self
    }

    /// Sets the `iterations` option of the base builder.
    pub fn iterations(mut self, iterations: usize) -> Self {
        self.base.iterations = iterations;
        self
    }

    /// Sets the `weight_function` option of the base builder.
    pub fn weight_function(mut self, wf: WeightFunction) -> Self {
        self.base.weight_function = wf;
        self
    }

    /// Sets the `robustness_method` option of the base builder.
    pub fn robustness_method(mut self, method: RobustnessMethod) -> Self {
        self.base.robustness_method = method;
        self
    }

    /// Sets the `scaling_method` option of the base builder.
    pub fn scaling_method(mut self, method: ScalingMethod) -> Self {
        self.base.scaling_method = method;
        self
    }

    /// Sets the `zero_weight_fallback` option of the base builder.
    pub fn zero_weight_fallback(mut self, fallback: ZeroWeightFallback) -> Self {
        self.base.zero_weight_fallback = fallback;
        self
    }

    /// Sets the `boundary_policy` option of the base builder.
    pub fn boundary_policy(mut self, policy: BoundaryPolicy) -> Self {
        self.base.boundary_policy = policy;
        self
    }

    /// Sets the `polynomial_degree` option of the base builder.
    pub fn polynomial_degree(mut self, degree: PolynomialDegree) -> Self {
        self.base.polynomial_degree = degree;
        self
    }

    /// Sets the `dimensions` option of the base builder.
    pub fn dimensions(mut self, dims: usize) -> Self {
        self.base.dimensions = dims;
        self
    }

    /// Sets the `distance_metric` option of the base builder.
    pub fn distance_metric(mut self, metric: DistanceMetric<T>) -> Self {
        self.base.distance_metric = metric;
        self
    }

    /// Sets the `surface_mode` option of the base builder.
    pub fn surface_mode(mut self, mode: SurfaceMode) -> Self {
        self.base.surface_mode = mode;
        self
    }

    /// Sets the optional `cell` value of the base builder.
    pub fn cell(mut self, cell: f64) -> Self {
        self.base.cell = Some(cell);
        self
    }

    /// Sets the optional `interpolation_vertices` value of the base builder.
    pub fn interpolation_vertices(mut self, vertices: usize) -> Self {
        self.base.interpolation_vertices = Some(vertices);
        self
    }

    /// Forwards to the base builder's own `boundary_degree_fallback` method.
    pub fn boundary_degree_fallback(mut self, enabled: bool) -> Self {
        self.base = self.base.boundary_degree_fallback(enabled);
        self
    }

    /// Sets the optional `auto_converge` tolerance of the base builder.
    pub fn auto_converge(mut self, tolerance: T) -> Self {
        self.base.auto_converge = Some(tolerance);
        self
    }

    /// Sets the `compute_residuals` flag of the base builder.
    pub fn compute_residuals(mut self, compute: bool) -> Self {
        self.base.compute_residuals = compute;
        self
    }

    /// Sets the `return_robustness_weights` flag of the base builder.
    pub fn return_robustness_weights(mut self, ret: bool) -> Self {
        self.base.return_robustness_weights = ret;
        self
    }

    /// Sets the `window_capacity` option of the base builder.
    pub fn window_capacity(mut self, capacity: usize) -> Self {
        self.base.window_capacity = capacity;
        self
    }

    /// Sets the `min_points` option of the base builder.
    pub fn min_points(mut self, min: usize) -> Self {
        self.base.min_points = min;
        self
    }

    /// Sets the `update_mode` option of the base builder.
    pub fn update_mode(mut self, mode: UpdateMode) -> Self {
        self.base.update_mode = mode;
        self
    }

    /// Consumes the builder and produces a [`ParallelOnlineLoess`].
    ///
    /// With the `cpu` feature enabled and the `parallel` flag set to `true`,
    /// the crate's parallel smooth, cross-validation, interval, vertex and
    /// k-d tree hooks are installed on the base builder before it is built.
    /// An unset flag is treated as `false`, the same default `new` applies.
    ///
    /// # Errors
    ///
    /// Returns the deferred error stored on the base builder, if any;
    /// otherwise propagates whatever error the base builder's `build` yields.
    pub fn build(self) -> Result<ParallelOnlineLoess<T>, LoessError> {
        let mut builder = self.base;

        // The builder is consumed either way, so move the error out instead
        // of cloning it.
        if let Some(err) = builder.deferred_error.take() {
            return Err(err);
        }

        #[cfg(feature = "cpu")]
        {
            // `base` is a public field, so the flag may be `None`; fall back to
            // sequential execution to stay consistent with `new()`.
            if builder.parallel.unwrap_or(false) {
                builder.custom_smooth_pass = Some(smooth_pass_parallel);
                builder.custom_cv_pass = Some(cv_pass_parallel);
                builder.custom_interval_pass = Some(interval_pass_parallel);
                builder.custom_vertex_pass = Some(vertex_pass_parallel);
                builder.custom_kdtree_builder = Some(build_kdtree_parallel);
            }
        }

        let processor = builder.build()?;
        Ok(ParallelOnlineLoess { processor })
    }
}
/// Online LOESS processor returned by `ParallelOnlineLoessBuilder::build`.
///
/// Wraps the base `OnlineLoess` processor and delegates every call to it.
pub struct ParallelOnlineLoess<T>
where
    T: FloatLinalg + DistanceLinalg + SolverLinalg,
{
    /// Wrapped base processor that performs the actual work.
    processor: loess_rs::internals::adapters::online::OnlineLoess<T>,
}
impl<T> ParallelOnlineLoess<T>
where
    T: FloatLinalg + DistanceLinalg + SolverLinalg + Float + Debug + Send + Sync + 'static,
{
    /// Feeds one observation (`x`, `y`) to the wrapped processor and returns
    /// its result unchanged.
    ///
    /// NOTE(review): the meaning of `Ok(None)` is defined by the base
    /// `OnlineLoess::add_point` — confirm there.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by the wrapped processor.
    pub fn add_point(&mut self, x: &[T], y: T) -> Result<Option<OnlineOutput<T>>, LoessError> {
        let inner = &mut self.processor;
        inner.add_point(x, y)
    }

    /// Returns the window size reported by the wrapped processor.
    pub fn window_size(&self) -> usize {
        let inner = &self.processor;
        inner.window_size()
    }

    /// Resets the wrapped processor by calling its `reset` method.
    pub fn reset(&mut self) {
        let inner = &mut self.processor;
        inner.reset();
    }
}