#[cfg(feature = "cpu")]
use crate::engine::executor::{smooth_pass_parallel, vertex_pass_parallel};
#[cfg(feature = "cpu")]
use crate::evaluation::cv::cv_pass_parallel;
#[cfg(feature = "cpu")]
use crate::evaluation::intervals::interval_pass_parallel;
use num_traits::Float;
use std::fmt::Debug;
use std::result::Result;
use loess_rs::internals::adapters::streaming::{MergeStrategy, StreamingLoessBuilder};
use loess_rs::internals::algorithms::regression::PolynomialDegree;
use loess_rs::internals::algorithms::regression::SolverLinalg;
use loess_rs::internals::algorithms::regression::ZeroWeightFallback;
use loess_rs::internals::algorithms::robustness::RobustnessMethod;
use loess_rs::internals::engine::executor::SurfaceMode;
use loess_rs::internals::engine::output::LoessResult;
use loess_rs::internals::math::boundary::BoundaryPolicy;
use loess_rs::internals::math::distance::DistanceLinalg;
use loess_rs::internals::math::distance::DistanceMetric;
use loess_rs::internals::math::kernel::WeightFunction;
use loess_rs::internals::math::linalg::FloatLinalg;
use loess_rs::internals::math::scaling::ScalingMethod;
use loess_rs::internals::primitives::backend::Backend;
use loess_rs::internals::primitives::errors::LoessError;
use crate::input::LoessInput;
use crate::math::neighborhood::build_kdtree_parallel;
/// Builder for [`ParallelStreamingLoess`].
///
/// This is a thin wrapper: every setting lives on the wrapped
/// [`StreamingLoessBuilder`] stored in `base`.
#[derive(Debug, Clone)]
pub struct ParallelStreamingLoessBuilder<T>
where
    T: FloatLinalg + DistanceLinalg + SolverLinalg,
{
    /// The underlying streaming builder that holds all configuration.
    pub base: StreamingLoessBuilder<T>,
}
impl<T: FloatLinalg + DistanceLinalg + SolverLinalg + Debug + Send + Sync> Default
for ParallelStreamingLoessBuilder<T>
{
fn default() -> Self {
Self::new()
}
}
impl<T: FloatLinalg + DistanceLinalg + SolverLinalg + Debug + Send + Sync>
    ParallelStreamingLoessBuilder<T>
{
    /// Creates a builder from `StreamingLoessBuilder::default()` with
    /// `parallel` set to `Some(true)`.
    ///
    /// This produces the same value as `Default::default()`.
    #[must_use]
    pub fn new() -> Self {
        let mut base = StreamingLoessBuilder::default();
        base.parallel = Some(true);
        Self { base }
    }

    /// Enables or disables parallel execution. The value is stored as
    /// `Some(parallel)` on the base builder.
    #[must_use]
    pub fn parallel(mut self, parallel: bool) -> Self {
        self.base.parallel = Some(parallel);
        self
    }

    /// Selects the execution backend. The value is stored as `Some(backend)`.
    #[must_use]
    pub fn backend(mut self, backend: Backend) -> Self {
        self.base.backend = Some(backend);
        self
    }

    /// Sets `fraction` on the base builder.
    #[must_use]
    pub fn fraction(mut self, fraction: T) -> Self {
        self.base.fraction = fraction;
        self
    }

    /// Sets `iterations` on the base builder.
    #[must_use]
    pub fn iterations(mut self, iterations: usize) -> Self {
        self.base.iterations = iterations;
        self
    }

    /// Sets the kernel `weight_function` on the base builder.
    #[must_use]
    pub fn weight_function(mut self, wf: WeightFunction) -> Self {
        self.base.weight_function = wf;
        self
    }

    /// Sets `robustness_method` on the base builder.
    #[must_use]
    pub fn robustness_method(mut self, method: RobustnessMethod) -> Self {
        self.base.robustness_method = method;
        self
    }

    /// Sets `scaling_method` on the base builder.
    #[must_use]
    pub fn scaling_method(mut self, method: ScalingMethod) -> Self {
        self.base.scaling_method = method;
        self
    }

    /// Sets `zero_weight_fallback` on the base builder.
    #[must_use]
    pub fn zero_weight_fallback(mut self, fallback: ZeroWeightFallback) -> Self {
        self.base.zero_weight_fallback = fallback;
        self
    }

    /// Sets `boundary_policy` on the base builder.
    #[must_use]
    pub fn boundary_policy(mut self, policy: BoundaryPolicy) -> Self {
        self.base.boundary_policy = policy;
        self
    }

    /// Sets `polynomial_degree` on the base builder.
    #[must_use]
    pub fn polynomial_degree(mut self, degree: PolynomialDegree) -> Self {
        self.base.polynomial_degree = degree;
        self
    }

    /// Sets the number of predictor `dimensions` on the base builder.
    #[must_use]
    pub fn dimensions(mut self, dims: usize) -> Self {
        self.base.dimensions = dims;
        self
    }

    /// Sets `distance_metric` on the base builder.
    #[must_use]
    pub fn distance_metric(mut self, metric: DistanceMetric<T>) -> Self {
        self.base.distance_metric = metric;
        self
    }

    /// Sets `surface_mode` on the base builder.
    #[must_use]
    pub fn surface_mode(mut self, mode: SurfaceMode) -> Self {
        self.base.surface_mode = mode;
        self
    }

    /// Sets `cell`, stored as `Some(cell)`.
    #[must_use]
    pub fn cell(mut self, cell: f64) -> Self {
        self.base.cell = Some(cell);
        self
    }

    /// Sets `interpolation_vertices`, stored as `Some(vertices)`.
    #[must_use]
    pub fn interpolation_vertices(mut self, vertices: usize) -> Self {
        self.base.interpolation_vertices = Some(vertices);
        self
    }

    /// Delegates to the base builder's own `boundary_degree_fallback` method
    /// instead of assigning a field directly.
    #[must_use]
    pub fn boundary_degree_fallback(mut self, enabled: bool) -> Self {
        self.base = self.base.boundary_degree_fallback(enabled);
        self
    }

    /// Sets `auto_converge` to `Some(tolerance)`.
    #[must_use]
    pub fn auto_converge(mut self, tolerance: T) -> Self {
        self.base.auto_converge = Some(tolerance);
        self
    }

    /// Sets `compute_residuals` on the base builder.
    #[must_use]
    pub fn compute_residuals(mut self, enabled: bool) -> Self {
        self.base.compute_residuals = enabled;
        self
    }

    /// Sets `return_robustness_weights` on the base builder.
    #[must_use]
    pub fn return_robustness_weights(mut self, enabled: bool) -> Self {
        self.base.return_robustness_weights = enabled;
        self
    }

    /// Sets `return_diagnostics` on the base builder.
    #[must_use]
    pub fn return_diagnostics(mut self, enabled: bool) -> Self {
        self.base.return_diagnostics = enabled;
        self
    }

    /// Sets the streaming `chunk_size` on the base builder.
    #[must_use]
    pub fn chunk_size(mut self, size: usize) -> Self {
        self.base.chunk_size = size;
        self
    }

    /// Sets the streaming `overlap` on the base builder.
    #[must_use]
    pub fn overlap(mut self, overlap: usize) -> Self {
        self.base.overlap = overlap;
        self
    }

    /// Sets `merge_strategy` on the base builder.
    #[must_use]
    pub fn merge_strategy(mut self, strategy: MergeStrategy) -> Self {
        self.base.merge_strategy = strategy;
        self
    }

    /// Finishes configuration and returns a [`ParallelStreamingLoess`].
    ///
    /// The underlying streaming processor is not created here. It is built
    /// lazily, on the first call to `process_chunk`.
    ///
    /// # Errors
    ///
    /// Returns a clone of any error the base builder recorded in
    /// `deferred_error`.
    pub fn build(self) -> Result<ParallelStreamingLoess<T>, LoessError> {
        if let Some(ref err) = self.base.deferred_error {
            return Err(err.clone());
        }
        Ok(ParallelStreamingLoess {
            config: self,
            processor: None,
        })
    }
}
/// Chunked LOESS smoother created by `ParallelStreamingLoessBuilder::build`.
///
/// The inner `StreamingLoess` is created on demand. It is built when
/// `process_chunk` is first called.
pub struct ParallelStreamingLoess<T>
where
    T: FloatLinalg + DistanceLinalg + SolverLinalg + Debug + Send + Sync,
{
    /// Configuration captured when the builder was finalised.
    config: ParallelStreamingLoessBuilder<T>,
    /// Underlying processor; `None` until it has been successfully built.
    processor: Option<loess_rs::internals::adapters::streaming::StreamingLoess<T>>,
}
impl<T: FloatLinalg + DistanceLinalg + SolverLinalg + Float + Debug + Send + Sync + 'static>
ParallelStreamingLoess<T>
{
pub fn process_chunk<I1, I2>(&mut self, x: &I1, y: &I2) -> Result<LoessResult<T>, LoessError>
where
I1: LoessInput<T> + ?Sized,
I2: LoessInput<T> + ?Sized,
{
let x_slice = x.as_loess_slice()?;
let y_slice = y.as_loess_slice()?;
if self.processor.is_none() {
let mut builder = self.config.base.clone();
#[cfg(feature = "cpu")]
{
if builder.parallel.unwrap_or(true) {
builder.custom_smooth_pass = Some(smooth_pass_parallel);
builder.custom_cv_pass = Some(cv_pass_parallel);
builder.custom_interval_pass = Some(interval_pass_parallel);
builder.custom_vertex_pass = Some(vertex_pass_parallel);
builder.custom_kdtree_builder = Some(build_kdtree_parallel);
}
}
self.processor = Some(builder.build()?);
}
self.processor
.as_mut()
.unwrap()
.process_chunk(x_slice, y_slice)
}
pub fn finalize(&mut self) -> Result<LoessResult<T>, LoessError> {
if let Some(ref mut proc) = self.processor {
proc.finalize()
} else {
Ok(LoessResult {
x: Vec::new(),
dimensions: self.config.base.dimensions,
distance_metric: self.config.base.distance_metric.clone(),
polynomial_degree: self.config.base.polynomial_degree,
y: Vec::new(),
standard_errors: None,
confidence_lower: None,
confidence_upper: None,
prediction_lower: None,
prediction_upper: None,
residuals: None,
robustness_weights: None,
diagnostics: None,
iterations_used: None,
fraction_used: self.config.base.fraction,
cv_scores: None,
enp: None,
trace_hat: None,
delta1: None,
delta2: None,
residual_scale: None,
leverage: None,
})
}
}
pub fn reset(&mut self) {
if let Some(ref mut proc) = self.processor {
proc.reset();
}
}
}