scirs2_interpolate/
statistical.rs

1//! Advanced statistical interpolation methods
2//!
3//! This module provides statistical interpolation techniques that go beyond
4//! deterministic interpolation, including:
5//! - Bootstrap confidence intervals
6//! - Bayesian interpolation with posterior distributions
7//! - Quantile interpolation/regression
8//! - Robust interpolation methods
9//! - Stochastic interpolation for random fields
10
11#![allow(clippy::too_many_arguments)]
12#![allow(dead_code)]
13
14use crate::error::{InterpolateError, InterpolateResult};
15use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
16use scirs2_core::numeric::{Float, FromPrimitive};
17use scirs2_core::random::{rngs::StdRng, Rng, SeedableRng};
18use scirs2_core::random::{Distribution, Normal, StandardNormal};
19use statrs::statistics::Statistics;
20use std::fmt::{Debug, Display};
21
/// Settings that control bootstrap resampling and confidence-interval construction.
#[derive(Debug, Clone)]
pub struct BootstrapConfig {
    /// How many bootstrap resamples to draw.
    pub n_samples: usize,
    /// Two-sided confidence level of the reported interval (e.g. `0.95` for 95%).
    pub confidence_level: f64,
    /// Optional RNG seed for reproducible resampling; `None` draws a fresh seed.
    pub seed: Option<u64>,
}

impl Default for BootstrapConfig {
    /// 1000 resamples, a 95% interval and no fixed seed.
    fn default() -> Self {
        let n_samples = 1000;
        let confidence_level = 0.95;
        Self {
            n_samples,
            confidence_level,
            seed: None,
        }
    }
}
42
43/// Result from bootstrap interpolation including confidence intervals
44#[derive(Debug, Clone)]
45pub struct BootstrapResult<T: Float> {
46    /// Point estimate (median of bootstrap samples)
47    pub estimate: Array1<T>,
48    /// Lower confidence bound
49    pub lower_bound: Array1<T>,
50    /// Upper confidence bound
51    pub upper_bound: Array1<T>,
52    /// Standard error estimate
53    pub std_error: Array1<T>,
54}
55
56/// Bootstrap interpolation with confidence intervals
57///
58/// This method performs interpolation with uncertainty quantification
59/// using bootstrap resampling of the input data.
60pub struct BootstrapInterpolator<T: Float> {
61    /// Configuration for bootstrap
62    config: BootstrapConfig,
63    /// Base interpolator factory
64    interpolator_factory:
65        Box<dyn Fn(&ArrayView1<T>, &ArrayView1<T>) -> InterpolateResult<Box<dyn Fn(T) -> T>>>,
66}
67
68impl<T: Float + FromPrimitive + Debug + Display + std::iter::Sum> BootstrapInterpolator<T> {
69    /// Create a new bootstrap interpolator
70    pub fn new<F>(config: BootstrapConfig, interpolator_factory: F) -> Self
71    where
72        F: Fn(&ArrayView1<T>, &ArrayView1<T>) -> InterpolateResult<Box<dyn Fn(T) -> T>> + 'static,
73    {
74        Self {
75            config,
76            interpolator_factory: Box::new(interpolator_factory),
77        }
78    }
79
80    /// Perform bootstrap interpolation at given points
81    pub fn interpolate(
82        &self,
83        x: &ArrayView1<T>,
84        y: &ArrayView1<T>,
85        xnew: &ArrayView1<T>,
86    ) -> InterpolateResult<BootstrapResult<T>> {
87        if x.len() != y.len() {
88            return Err(InterpolateError::DimensionMismatch(
89                "x and y must have the same length".to_string(),
90            ));
91        }
92
93        let n = x.len();
94        let m = xnew.len();
95        let mut rng = match self.config.seed {
96            Some(seed) => StdRng::seed_from_u64(seed),
97            None => {
98                let mut rng = scirs2_core::random::rng();
99                StdRng::from_rng(&mut rng)
100            }
101        };
102
103        // Storage for bootstrap samples
104        let mut bootstrap_results = Array2::<T>::zeros((self.config.n_samples, m));
105
106        // Perform bootstrap resampling
107        for i in 0..self.config.n_samples {
108            // Resample indices with replacement
109            let indices: Vec<usize> = (0..n).map(|_| rng.gen_range(0..n)).collect();
110
111            // Create resampled data
112            let x_resampled = indices.iter().map(|&idx| x[idx]).collect::<Array1<_>>();
113            let y_resampled = indices.iter().map(|&idx| y[idx]).collect::<Array1<_>>();
114
115            // Create interpolator for this bootstrap sample
116            let interpolator =
117                (self.interpolator_factory)(&x_resampled.view(), &y_resampled.view())?;
118
119            // Evaluate at _new points
120            for (j, &x_val) in xnew.iter().enumerate() {
121                bootstrap_results[[i, j]] = interpolator(x_val);
122            }
123        }
124
125        // Calculate statistics
126        let alpha = T::from(1.0 - self.config.confidence_level).unwrap();
127        let lower_percentile = alpha / T::from(2.0).unwrap();
128        let upper_percentile = T::one() - alpha / T::from(2.0).unwrap();
129
130        let mut estimate = Array1::zeros(m);
131        let mut lower_bound = Array1::zeros(m);
132        let mut upper_bound = Array1::zeros(m);
133        let mut std_error = Array1::zeros(m);
134
135        for j in 0..m {
136            let column = bootstrap_results.index_axis(Axis(1), j);
137            let mut sorted_col = column.to_vec();
138            sorted_col.sort_by(|a, b| a.partial_cmp(b).unwrap());
139
140            // Median as point estimate
141            let median_idx = self.config.n_samples / 2;
142            estimate[j] = sorted_col[median_idx];
143
144            // Confidence bounds
145            let lower_idx = (lower_percentile * T::from(self.config.n_samples).unwrap())
146                .to_usize()
147                .unwrap();
148            let upper_idx = (upper_percentile * T::from(self.config.n_samples).unwrap())
149                .to_usize()
150                .unwrap();
151            lower_bound[j] = sorted_col[lower_idx];
152            upper_bound[j] = sorted_col[upper_idx];
153
154            // Standard error
155            let mean = column.mean().unwrap();
156            let variance = column
157                .iter()
158                .map(|&val| (val - mean) * (val - mean))
159                .sum::<T>()
160                / T::from(self.config.n_samples - 1).unwrap();
161            std_error[j] = variance.sqrt();
162        }
163
164        Ok(BootstrapResult {
165            estimate,
166            lower_bound,
167            upper_bound,
168            std_error,
169        })
170    }
171}
172
173/// Configuration for Bayesian interpolation
174pub struct BayesianConfig<T: Float> {
175    /// Prior mean function
176    pub prior_mean: Box<dyn Fn(T) -> T>,
177    /// Prior variance
178    pub prior_variance: T,
179    /// Measurement noise variance
180    pub noise_variance: T,
181    /// RBF kernel length scale parameter
182    pub length_scale: T,
183    /// Number of posterior samples to draw
184    pub n_posterior_samples: usize,
185}
186
187impl<T: Float + FromPrimitive> Default for BayesianConfig<T> {
188    fn default() -> Self {
189        Self {
190            prior_mean: Box::new(|_| T::zero()),
191            prior_variance: T::one(),
192            noise_variance: T::from(0.01).unwrap(),
193            length_scale: T::one(),
194            n_posterior_samples: 100,
195        }
196    }
197}
198
199impl<T: Float + FromPrimitive> BayesianConfig<T> {
200    /// Set the RBF kernel length scale parameter
201    pub fn with_length_scale(mut self, lengthscale: T) -> Self {
202        self.length_scale = lengthscale;
203        self
204    }
205
206    /// Set the prior variance
207    pub fn with_prior_variance(mut self, variance: T) -> Self {
208        self.prior_variance = variance;
209        self
210    }
211
212    /// Set the noise variance
213    pub fn with_noise_variance(mut self, variance: T) -> Self {
214        self.noise_variance = variance;
215        self
216    }
217
218    /// Set the number of posterior samples
219    pub fn with_n_posterior_samples(mut self, nsamples: usize) -> Self {
220        self.n_posterior_samples = nsamples;
221        self
222    }
223}
224
225/// Bayesian interpolation with full posterior distribution
226///
227/// This provides interpolation with full uncertainty quantification
228/// through Bayesian inference.
229pub struct BayesianInterpolator<T: Float> {
230    config: BayesianConfig<T>,
231    x_obs: Array1<T>,
232    y_obs: Array1<T>,
233}
234
235impl<
236        T: Float
237            + FromPrimitive
238            + Debug
239            + Display
240            + std::ops::AddAssign
241            + std::ops::SubAssign
242            + std::ops::MulAssign
243            + std::ops::DivAssign
244            + std::ops::RemAssign,
245    > BayesianInterpolator<T>
246{
247    /// Create a new Bayesian interpolator
248    pub fn new(
249        x: &ArrayView1<T>,
250        y: &ArrayView1<T>,
251        config: BayesianConfig<T>,
252    ) -> InterpolateResult<Self> {
253        if x.len() != y.len() {
254            return Err(InterpolateError::DimensionMismatch(
255                "x and y must have the same length".to_string(),
256            ));
257        }
258
259        Ok(Self {
260            config,
261            x_obs: x.to_owned(),
262            y_obs: y.to_owned(),
263        })
264    }
265
266    /// Get posterior mean at given points using proper Gaussian process regression
267    pub fn posterior_mean(&self, xnew: &ArrayView1<T>) -> InterpolateResult<Array1<T>> {
268        let n = self.x_obs.len();
269        let m = xnew.len();
270
271        if n == 0 {
272            return Err(InterpolateError::invalid_input(
273                "No observed data points".to_string(),
274            ));
275        }
276
277        // Compute covariance matrix K(X, X) + σ²I
278        let mut k_xx = Array2::<T>::zeros((n, n));
279        let length_scale = self.config.length_scale;
280
281        // Build covariance matrix with RBF kernel
282        for i in 0..n {
283            for j in 0..n {
284                let dist_sq = (self.x_obs[i] - self.x_obs[j]).powi(2);
285                k_xx[[i, j]] = self.config.prior_variance
286                    * (-dist_sq / (T::from(2.0).unwrap() * length_scale.powi(2))).exp();
287
288                // Add noise variance to diagonal
289                if i == j {
290                    k_xx[[i, j]] += self.config.noise_variance;
291                }
292            }
293        }
294
295        // Solve the linear system K * weights = y_obs using Cholesky decomposition
296        // This is more numerically stable than matrix inversion
297        let weights = match self.solve_gp_system(&k_xx.view(), &self.y_obs.view()) {
298            Ok(w) => w,
299            Err(_) => {
300                // Fallback to regularized system if Cholesky fails
301                let regularization = T::from(1e-6).unwrap();
302                for i in 0..n {
303                    k_xx[[i, i]] += regularization;
304                }
305                self.solve_gp_system(&k_xx.view(), &self.y_obs.view())?
306            }
307        };
308
309        // Compute cross-covariance K(X*, X)
310        let mut k_star_x = Array2::<T>::zeros((m, n));
311        for i in 0..m {
312            for j in 0..n {
313                let dist_sq = (xnew[i] - self.x_obs[j]).powi(2);
314                k_star_x[[i, j]] = self.config.prior_variance
315                    * (-dist_sq / (T::from(2.0).unwrap() * length_scale.powi(2))).exp();
316            }
317        }
318
319        // Compute posterior mean: μ* = K(X*, X) * weights
320        let mut mean = Array1::zeros(m);
321        for i in 0..m {
322            let mut sum = T::zero();
323            for j in 0..n {
324                sum += k_star_x[[i, j]] * weights[j];
325            }
326            // Add prior mean
327            mean[i] = (self.config.prior_mean)(xnew[i]) + sum;
328        }
329
330        Ok(mean)
331    }
332
333    /// Solve the GP linear system using available numerical methods
334    fn solve_gp_system(
335        &self,
336        k_matrix: &ArrayView2<T>,
337        y_obs: &ArrayView1<T>,
338    ) -> InterpolateResult<Array1<T>> {
339        use crate::structured_matrix::solve_dense_system;
340
341        // Try using the structured _matrix solver
342        match solve_dense_system(k_matrix, y_obs) {
343            Ok(solution) => Ok(solution),
344            Err(_) => {
345                // Additional fallback: use simple weighted average if _matrix is ill-conditioned
346                let n = y_obs.len();
347                let weights = Array1::from_elem(n, T::one() / T::from(n).unwrap());
348                Ok(weights)
349            }
350        }
351    }
352
353    /// Draw samples from the posterior distribution
354    pub fn posterior_samples(
355        &self,
356        xnew: &ArrayView1<T>,
357        n_samples: usize,
358    ) -> InterpolateResult<Array2<T>> {
359        let mean = self.posterior_mean(xnew)?;
360        let m = xnew.len();
361
362        // For computational efficiency, we use a simplified approach that captures
363        // the main posterior uncertainty while avoiding expensive matrix operations.
364        // A full implementation would compute the posterior covariance matrix:
365        // Σ* = K(X*, X*) - K(X*, X)[K(X, X) + σ²I]^(-1)K(X, X*)
366
367        let mut samples = Array2::zeros((n_samples, m));
368        let mut rng = scirs2_core::random::rng();
369
370        // Compute approximate posterior variance at each point
371        let length_scale = T::one();
372        for j in 0..m {
373            // Compute posterior variance as prior variance minus reduction from observations
374            let mut reduction_factor = T::zero();
375            let mut total_influence = T::zero();
376
377            for i in 0..self.x_obs.len() {
378                let dist_sq = (xnew[j] - self.x_obs[i]).powi(2);
379                let influence = (-dist_sq / (T::from(2.0).unwrap() * length_scale.powi(2))).exp();
380                total_influence += influence;
381                reduction_factor += influence * influence;
382            }
383
384            // Approximate posterior variance
385            let noise_ratio = self.config.noise_variance / self.config.prior_variance;
386            let posterior_var = self.config.prior_variance
387                * (T::one()
388                    - reduction_factor / (total_influence + noise_ratio + T::from(1e-8).unwrap()));
389
390            // Ensure positive variance
391            let std_dev = posterior_var.max(T::from(1e-12).unwrap()).sqrt();
392
393            // Draw _samples for this query point
394            for i in 0..n_samples {
395                if let Ok(normal) =
396                    Normal::new(mean[j].to_f64().unwrap(), std_dev.to_f64().unwrap())
397                {
398                    samples[[i, j]] = T::from(normal.sample(&mut rng)).unwrap();
399                } else {
400                    samples[[i, j]] = mean[j];
401                }
402            }
403        }
404
405        Ok(samples)
406    }
407}
408
409/// Quantile interpolation/regression
410///
411/// Interpolates specific quantiles of the response distribution
412pub struct QuantileInterpolator<T: Float> {
413    /// Quantile to interpolate (e.g., 0.5 for median)
414    quantile: T,
415    /// Bandwidth for local quantile estimation
416    bandwidth: T,
417}
418
419impl<T: Float + FromPrimitive + Debug + Display> QuantileInterpolator<T>
420where
421    T: std::iter::Sum<T> + for<'a> std::iter::Sum<&'a T>,
422{
423    /// Create a new quantile interpolator
424    pub fn new(quantile: T, bandwidth: T) -> InterpolateResult<Self> {
425        if quantile <= T::zero() || quantile >= T::one() {
426            return Err(InterpolateError::InvalidValue(
427                "Quantile must be between 0 and 1".to_string(),
428            ));
429        }
430
431        Ok(Self {
432            quantile,
433            bandwidth,
434        })
435    }
436
437    /// Interpolate quantile at given points
438    pub fn interpolate(
439        &self,
440        x: &ArrayView1<T>,
441        y: &ArrayView1<T>,
442        xnew: &ArrayView1<T>,
443    ) -> InterpolateResult<Array1<T>> {
444        if x.len() != y.len() {
445            return Err(InterpolateError::DimensionMismatch(
446                "x and y must have the same length".to_string(),
447            ));
448        }
449
450        let n = x.len();
451        let m = xnew.len();
452        let mut result = Array1::zeros(m);
453
454        // Local quantile regression
455        for j in 0..m {
456            let x_target = xnew[j];
457
458            // Compute weights based on distance
459            let mut weights = Vec::with_capacity(n);
460            let mut weighted_values = Vec::with_capacity(n);
461
462            for i in 0..n {
463                let dist = (x[i] - x_target).abs() / self.bandwidth;
464                let weight = if dist < T::one() {
465                    (T::one() - dist * dist * dist).powi(3) // Tricube kernel
466                } else {
467                    T::zero()
468                };
469
470                if weight > T::epsilon() {
471                    weights.push(weight);
472                    weighted_values.push((y[i], weight));
473                }
474            }
475
476            if weighted_values.is_empty() {
477                // No nearby points, use nearest neighbor
478                let nearest_idx = x
479                    .iter()
480                    .enumerate()
481                    .min_by_key(|(_, &xi)| ((xi - x_target).abs().to_f64().unwrap() * 1e6) as i64)
482                    .map(|(i_, _)| i_)
483                    .unwrap();
484                result[j] = y[nearest_idx];
485            } else {
486                // Sort by value
487                weighted_values.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
488
489                // Find weighted quantile
490                let total_weight: T = weights.iter().sum();
491                let target_weight = self.quantile * total_weight;
492
493                let mut cumulative_weight = T::zero();
494                for (val, weight) in weighted_values {
495                    cumulative_weight = cumulative_weight + weight;
496                    if cumulative_weight >= target_weight {
497                        result[j] = val;
498                        break;
499                    }
500                }
501            }
502        }
503
504        Ok(result)
505    }
506}
507
508/// Robust interpolation methods resistant to outliers
509pub struct RobustInterpolator<T: Float> {
510    /// Tuning constant for robustness
511    tuning_constant: T,
512    /// Maximum iterations for iterative reweighting
513    max_iterations: usize,
514    /// Convergence tolerance
515    tolerance: T,
516}
517
518impl<T: Float + FromPrimitive + Debug + Display> RobustInterpolator<T> {
519    /// Create a new robust interpolator using M-estimation
520    pub fn new(_tuningconstant: T) -> Self {
521        Self {
522            tuning_constant: _tuningconstant,
523            max_iterations: 100,
524            tolerance: T::from(1e-6).unwrap(),
525        }
526    }
527
528    /// Perform robust interpolation using iteratively reweighted least squares
529    pub fn interpolate(
530        &self,
531        x: &ArrayView1<T>,
532        y: &ArrayView1<T>,
533        xnew: &ArrayView1<T>,
534    ) -> InterpolateResult<Array1<T>> {
535        // Use local polynomial regression with robust weights
536        let n = x.len();
537        let m = xnew.len();
538        let mut result = Array1::zeros(m);
539
540        for j in 0..m {
541            let x_target = xnew[j];
542
543            // Initial weights (uniform)
544            let mut weights = vec![T::one(); n];
545            let mut prev_estimate = T::zero();
546
547            // Iteratively reweighted least squares
548            for _iter in 0..self.max_iterations {
549                // Weighted linear regression
550                let mut sum_w = T::zero();
551                let mut sum_wx = T::zero();
552                let mut sum_wy = T::zero();
553                let mut sum_wxx = T::zero();
554                let mut sum_wxy = T::zero();
555
556                for i in 0..n {
557                    let w = weights[i];
558                    let dx = x[i] - x_target;
559                    sum_w = sum_w + w;
560                    sum_wx = sum_wx + w * dx;
561                    sum_wy = sum_wy + w * y[i];
562                    sum_wxx = sum_wxx + w * dx * dx;
563                    sum_wxy = sum_wxy + w * dx * y[i];
564                }
565
566                // Solve for coefficients
567                let det = sum_w * sum_wxx - sum_wx * sum_wx;
568                let estimate = if det.abs() > T::epsilon() {
569                    (sum_wxx * sum_wy - sum_wx * sum_wxy) / det
570                } else {
571                    sum_wy / sum_w
572                };
573
574                // Check convergence
575                if (estimate - prev_estimate).abs() < self.tolerance {
576                    result[j] = estimate;
577                    break;
578                }
579                prev_estimate = estimate;
580
581                // Update weights using Huber's psi function
582                for i in 0..n {
583                    let residual = y[i] - estimate;
584                    let scaled_residual = residual / self.tuning_constant;
585
586                    weights[i] = if scaled_residual.abs() <= T::one() {
587                        T::one()
588                    } else {
589                        T::one() / scaled_residual.abs()
590                    };
591                }
592            }
593
594            result[j] = prev_estimate;
595        }
596
597        Ok(result)
598    }
599}
600
601/// Stochastic interpolation for random fields
602///
603/// Provides interpolation that preserves the stochastic properties
604/// of the underlying random field.
605pub struct StochasticInterpolator<T: Float> {
606    /// Correlation length scale
607    correlation_length: T,
608    /// Field variance
609    field_variance: T,
610    /// Number of realizations to generate
611    n_realizations: usize,
612}
613
614impl<T: Float + FromPrimitive + Debug + Display> StochasticInterpolator<T> {
615    /// Create a new stochastic interpolator
616    pub fn new(correlation_length: T, field_variance: T, n_realizations: usize) -> Self {
617        Self {
618            correlation_length,
619            field_variance,
620            n_realizations,
621        }
622    }
623
624    /// Generate stochastic realizations of the interpolated field
625    pub fn interpolate_realizations(
626        &self,
627        x: &ArrayView1<T>,
628        y: &ArrayView1<T>,
629        xnew: &ArrayView1<T>,
630    ) -> InterpolateResult<Array2<T>> {
631        let n = x.len();
632        let m = xnew.len();
633        let mut realizations = Array2::zeros((self.n_realizations, m));
634
635        let mut rng = scirs2_core::random::rng();
636
637        for r in 0..self.n_realizations {
638            // Generate a realization using conditional simulation
639            for j in 0..m {
640                let x_target = xnew[j];
641
642                // Kriging interpolation with added noise
643                let mut weighted_sum = T::zero();
644                let mut weight_sum = T::zero();
645
646                for i in 0..n {
647                    let dist = (x[i] - x_target).abs() / self.correlation_length;
648                    let weight = (-dist * dist).exp();
649                    weighted_sum = weighted_sum + weight * y[i];
650                    weight_sum = weight_sum + weight;
651                }
652
653                let mean = if weight_sum > T::epsilon() {
654                    weighted_sum / weight_sum
655                } else {
656                    T::zero()
657                };
658
659                // Add stochastic component
660                let std_dev =
661                    (self.field_variance * (T::one() - weight_sum / T::from(n).unwrap())).sqrt();
662                let normal_sample: f64 = StandardNormal.sample(&mut rng);
663                let noise: T = T::from(normal_sample).unwrap() * std_dev;
664
665                realizations[[r, j]] = mean + noise;
666            }
667        }
668
669        Ok(realizations)
670    }
671
672    /// Get mean and variance of the stochastic interpolation
673    pub fn interpolate_statistics(
674        &self,
675        x: &ArrayView1<T>,
676        y: &ArrayView1<T>,
677        xnew: &ArrayView1<T>,
678    ) -> InterpolateResult<(Array1<T>, Array1<T>)> {
679        let realizations = self.interpolate_realizations(x, y, xnew)?;
680
681        let mean = realizations.mean_axis(Axis(0)).unwrap();
682        let variance = realizations.var_axis(Axis(0), T::from(1.0).unwrap());
683
684        Ok((mean, variance))
685    }
686}
687
688/// Factory functions for creating statistical interpolators
689/// Create a bootstrap interpolator with linear base interpolation
690#[allow(dead_code)]
691pub fn make_bootstrap_linear_interpolator<
692    T: Float + FromPrimitive + Debug + Display + 'static + std::iter::Sum,
693>(
694    config: BootstrapConfig,
695) -> BootstrapInterpolator<T> {
696    BootstrapInterpolator::new(config, |x, y| {
697        // Create a simple linear interpolator
698        let x_owned = x.to_owned();
699        let y_owned = y.to_owned();
700        Ok(Box::new(move |xnew| {
701            // Simple linear interpolation
702            if xnew <= x_owned[0] {
703                y_owned[0]
704            } else if xnew >= x_owned[x_owned.len() - 1] {
705                y_owned[y_owned.len() - 1]
706            } else {
707                // Find surrounding points
708                let mut i = 0;
709                for j in 1..x_owned.len() {
710                    if xnew <= x_owned[j] {
711                        i = j - 1;
712                        break;
713                    }
714                }
715
716                let alpha = (xnew - x_owned[i]) / (x_owned[i + 1] - x_owned[i]);
717                y_owned[i] * (T::one() - alpha) + y_owned[i + 1] * alpha
718            }
719        }))
720    })
721}
722
723/// Create a Bayesian interpolator with default configuration
724#[allow(dead_code)]
725pub fn make_bayesian_interpolator<T: crate::traits::InterpolationFloat>(
726    x: &ArrayView1<T>,
727    y: &ArrayView1<T>,
728) -> InterpolateResult<BayesianInterpolator<T>> {
729    BayesianInterpolator::new(x, y, BayesianConfig::default())
730}
731
732/// Create a median (0.5 quantile) interpolator
733#[allow(dead_code)]
734pub fn make_median_interpolator<T>(bandwidth: T) -> InterpolateResult<QuantileInterpolator<T>>
735where
736    T: Float + FromPrimitive + Debug + Display + std::iter::Sum<T> + for<'a> std::iter::Sum<&'a T>,
737{
738    QuantileInterpolator::new(T::from(0.5).unwrap(), bandwidth)
739}
740
741/// Create a robust interpolator with default Huber tuning
742#[allow(dead_code)]
743pub fn make_robust_interpolator<T: crate::traits::InterpolationFloat>() -> RobustInterpolator<T> {
744    RobustInterpolator::new(T::from(1.345).unwrap()) // Huber's recommended value
745}
746
747/// Create a stochastic interpolator with default parameters
748#[allow(dead_code)]
749pub fn make_stochastic_interpolator<T: crate::traits::InterpolationFloat>(
750    correlation_length: T,
751) -> StochasticInterpolator<T> {
752    StochasticInterpolator::new(correlation_length, T::one(), 100)
753}
754
755/// Ensemble interpolation combining multiple methods
756///
757/// Provides interpolation using an ensemble of different methods
758/// to improve robustness and uncertainty quantification.
759pub struct EnsembleInterpolator<T: Float> {
760    /// Weight for each interpolation method
761    weights: Array1<T>,
762    /// Interpolation methods in the ensemble
763    methods: Vec<
764        Box<dyn Fn(&ArrayView1<T>, &ArrayView1<T>, &ArrayView1<T>) -> InterpolateResult<Array1<T>>>,
765    >,
766    /// Whether to normalize weights
767    normalize_weights: bool,
768}
769
770impl<T: crate::traits::InterpolationFloat> EnsembleInterpolator<T> {
771    /// Create a new ensemble interpolator
772    pub fn new() -> Self {
773        Self {
774            weights: Array1::zeros(0),
775            methods: Vec::new(),
776            normalize_weights: true,
777        }
778    }
779
780    /// Add a linear interpolation method to the ensemble
781    pub fn add_linear_method(mut self, weight: T) -> Self {
782        self.weights = if self.weights.is_empty() {
783            Array1::from_vec(vec![weight])
784        } else {
785            let mut new_weights = self.weights.to_vec();
786            new_weights.push(weight);
787            Array1::from_vec(new_weights)
788        };
789
790        self.methods.push(Box::new(|x, y, xnew| {
791            let mut result = Array1::zeros(xnew.len());
792            for (i, &x_val) in xnew.iter().enumerate() {
793                // Linear interpolation
794                if x_val <= x[0] {
795                    result[i] = y[0];
796                } else if x_val >= x[x.len() - 1] {
797                    result[i] = y[y.len() - 1];
798                } else {
799                    // Find surrounding points
800                    for j in 1..x.len() {
801                        if x_val <= x[j] {
802                            let alpha = (x_val - x[j - 1]) / (x[j] - x[j - 1]);
803                            result[i] = y[j - 1] * (T::one() - alpha) + y[j] * alpha;
804                            break;
805                        }
806                    }
807                }
808            }
809            Ok(result)
810        }));
811        self
812    }
813
814    /// Add a cubic interpolation method to the ensemble
815    pub fn add_cubic_method(mut self, weight: T) -> Self {
816        self.weights = if self.weights.is_empty() {
817            Array1::from_vec(vec![weight])
818        } else {
819            let mut new_weights = self.weights.to_vec();
820            new_weights.push(weight);
821            Array1::from_vec(new_weights)
822        };
823
824        self.methods.push(Box::new(|x, y, xnew| {
825            // Cubic spline interpolation using natural boundary conditions
826            use crate::spline::CubicSpline;
827
828            // Need at least 3 points for cubic spline
829            if x.len() < 3 {
830                return Err(InterpolateError::invalid_input(
831                    "Cubic spline requires at least 3 data points".to_string(),
832                ));
833            }
834
835            // Create cubic spline with natural boundary conditions
836            let spline = CubicSpline::new(x, y)?;
837
838            // Evaluate at all query points
839            let mut result = Array1::zeros(xnew.len());
840            for (i, &x_val) in xnew.iter().enumerate() {
841                // Handle extrapolation by clamping to boundary values
842                if x_val < x[0] {
843                    result[i] = y[0];
844                } else if x_val > x[x.len() - 1] {
845                    result[i] = y[y.len() - 1];
846                } else {
847                    // Evaluate cubic spline within the valid range
848                    result[i] = spline.evaluate(x_val)?;
849                }
850            }
851            Ok(result)
852        }));
853        self
854    }
855
856    /// Perform ensemble interpolation
857    pub fn interpolate(
858        &self,
859        x: &ArrayView1<T>,
860        y: &ArrayView1<T>,
861        xnew: &ArrayView1<T>,
862    ) -> InterpolateResult<Array1<T>> {
863        if self.methods.is_empty() {
864            return Err(InterpolateError::InvalidState(
865                "No interpolation methods in ensemble".to_string(),
866            ));
867        }
868
869        let mut weighted_results = Array1::zeros(xnew.len());
870        let mut total_weight = T::zero();
871
872        for (i, method) in self.methods.iter().enumerate() {
873            let result = method(x, y, xnew)?;
874            let weight = self.weights[i];
875
876            for j in 0..xnew.len() {
877                weighted_results[j] += weight * result[j];
878            }
879            total_weight += weight;
880        }
881
882        // Normalize if requested
883        if self.normalize_weights && total_weight > T::zero() {
884            for val in weighted_results.iter_mut() {
885                *val /= total_weight;
886            }
887        }
888
889        Ok(weighted_results)
890    }
891
892    /// Get ensemble variance (measure of uncertainty)
893    pub fn interpolate_with_variance(
894        &self,
895        x: &ArrayView1<T>,
896        y: &ArrayView1<T>,
897        xnew: &ArrayView1<T>,
898    ) -> InterpolateResult<(Array1<T>, Array1<T>)> {
899        if self.methods.is_empty() {
900            return Err(InterpolateError::InvalidState(
901                "No interpolation methods in ensemble".to_string(),
902            ));
903        }
904
905        let mut all_results = Vec::new();
906
907        // Collect results from all methods
908        for method in self.methods.iter() {
909            let result = method(x, y, xnew)?;
910            all_results.push(result);
911        }
912
913        // Compute weighted mean
914        let mut weighted_mean = Array1::zeros(xnew.len());
915        let mut total_weight = T::zero();
916
917        for (i, result) in all_results.iter().enumerate() {
918            let weight = self.weights[i];
919            for j in 0..xnew.len() {
920                weighted_mean[j] += weight * result[j];
921            }
922            total_weight += weight;
923        }
924
925        if total_weight > T::zero() {
926            for val in weighted_mean.iter_mut() {
927                *val /= total_weight;
928            }
929        }
930
931        // Compute weighted variance
932        let mut variance = Array1::zeros(xnew.len());
933        if all_results.len() > 1 {
934            for (i, result) in all_results.iter().enumerate() {
935                let weight = self.weights[i];
936                for j in 0..xnew.len() {
937                    let diff = result[j] - weighted_mean[j];
938                    variance[j] += weight * diff * diff;
939                }
940            }
941
942            if total_weight > T::zero() {
943                for val in variance.iter_mut() {
944                    *val /= total_weight;
945                }
946            }
947        }
948
949        Ok((weighted_mean, variance))
950    }
951}
952
953impl<T: crate::traits::InterpolationFloat> Default for EnsembleInterpolator<T> {
954    fn default() -> Self {
955        Self::new()
956    }
957}
958
/// Cross-validation based uncertainty estimation
///
/// Provides uncertainty estimates using leave-one-out cross-validation
/// or k-fold cross-validation. The interpolator is refit on resampled
/// training sets, and the spread of its predictions at the query points
/// is reported.
#[derive(Debug, Clone)]
pub struct CrossValidationUncertainty {
    /// Number of folds for k-fold CV. `0` selects leave-one-out, and any value
    /// at least the number of data points also falls back to leave-one-out.
    k_folds: usize,
    /// Random seed for fold assignment (`None` draws a fresh seed on each run)
    seed: Option<u64>,
}
969
970impl CrossValidationUncertainty {
971    /// Create a new cross-validation uncertainty estimator
972    pub fn new(_kfolds: usize) -> Self {
973        Self {
974            k_folds: _kfolds,
975            seed: None,
976        }
977    }
978
979    /// Set random seed for reproducible fold assignment
980    pub fn with_seed(mut self, seed: u64) -> Self {
981        self.seed = Some(seed);
982        self
983    }
984
985    /// Estimate uncertainty using cross-validation
986    pub fn estimate_uncertainty<T, F>(
987        &self,
988        x: &ArrayView1<T>,
989        y: &ArrayView1<T>,
990        xnew: &ArrayView1<T>,
991        interpolator_factory: F,
992    ) -> InterpolateResult<(Array1<T>, Array1<T>)>
993    where
994        T: Clone
995            + Copy
996            + scirs2_core::numeric::Float
997            + scirs2_core::numeric::FromPrimitive
998            + std::iter::Sum,
999        F: Fn(&ArrayView1<T>, &ArrayView1<T>, &ArrayView1<T>) -> InterpolateResult<Array1<T>>,
1000    {
1001        let n = x.len();
1002        let _m = xnew.len();
1003
1004        if self.k_folds == 0 || self.k_folds >= n {
1005            // Leave-one-out cross-validation
1006            self.leave_one_out_uncertainty(x, y, xnew, interpolator_factory)
1007        } else {
1008            // K-fold cross-validation
1009            self.k_fold_uncertainty(x, y, xnew, interpolator_factory)
1010        }
1011    }
1012
1013    fn leave_one_out_uncertainty<T, F>(
1014        &self,
1015        x: &ArrayView1<T>,
1016        y: &ArrayView1<T>,
1017        xnew: &ArrayView1<T>,
1018        interpolator_factory: F,
1019    ) -> InterpolateResult<(Array1<T>, Array1<T>)>
1020    where
1021        T: Clone
1022            + Copy
1023            + scirs2_core::numeric::Float
1024            + scirs2_core::numeric::FromPrimitive
1025            + std::iter::Sum,
1026        F: Fn(&ArrayView1<T>, &ArrayView1<T>, &ArrayView1<T>) -> InterpolateResult<Array1<T>>,
1027    {
1028        let n = x.len();
1029        let m = xnew.len();
1030        let mut predictions = Array2::zeros((n, m));
1031
1032        // Leave-one-out cross-validation
1033        for i in 0..n {
1034            // Create training set without point i
1035            let mut x_train = Vec::new();
1036            let mut y_train = Vec::new();
1037
1038            for j in 0..n {
1039                if j != i {
1040                    x_train.push(x[j]);
1041                    y_train.push(y[j]);
1042                }
1043            }
1044
1045            let x_train_array = Array1::from_vec(x_train);
1046            let y_train_array = Array1::from_vec(y_train);
1047
1048            // Train on reduced dataset and predict
1049            let pred = interpolator_factory(&x_train_array.view(), &y_train_array.view(), xnew)?;
1050            for j in 0..m {
1051                predictions[[i, j]] = pred[j];
1052            }
1053        }
1054
1055        // Compute mean and variance of predictions
1056        let mut mean = Array1::zeros(m);
1057        let mut variance = Array1::zeros(m);
1058
1059        for j in 0..m {
1060            let col = predictions.column(j);
1061            let sum: T = col.iter().copied().sum();
1062            mean[j] = sum / T::from(n).unwrap();
1063
1064            let var_sum: T = col
1065                .iter()
1066                .map(|&val| (val - mean[j]) * (val - mean[j]))
1067                .sum();
1068            variance[j] = var_sum / T::from(n - 1).unwrap();
1069        }
1070
1071        Ok((mean, variance))
1072    }
1073
1074    fn k_fold_uncertainty<T, F>(
1075        &self,
1076        x: &ArrayView1<T>,
1077        y: &ArrayView1<T>,
1078        xnew: &ArrayView1<T>,
1079        interpolator_factory: F,
1080    ) -> InterpolateResult<(Array1<T>, Array1<T>)>
1081    where
1082        T: Clone
1083            + Copy
1084            + scirs2_core::numeric::Float
1085            + scirs2_core::numeric::FromPrimitive
1086            + std::iter::Sum,
1087        F: Fn(&ArrayView1<T>, &ArrayView1<T>, &ArrayView1<T>) -> InterpolateResult<Array1<T>>,
1088    {
1089        let n = x.len();
1090        let m = xnew.len();
1091        let fold_size = n / self.k_folds;
1092        let mut predictions = Vec::new();
1093
1094        let mut rng = match self.seed {
1095            Some(seed) => StdRng::seed_from_u64(seed),
1096            None => {
1097                let mut rng = scirs2_core::random::rng();
1098                StdRng::from_rng(&mut rng)
1099            }
1100        };
1101
1102        // Create shuffled indices
1103        let mut indices: Vec<usize> = (0..n).collect();
1104        use scirs2_core::random::seq::SliceRandom;
1105        indices.shuffle(&mut rng);
1106
1107        // K-fold cross-validation
1108        for fold in 0..self.k_folds {
1109            let start_idx = fold * fold_size;
1110            let end_idx = if fold == self.k_folds - 1 {
1111                n
1112            } else {
1113                (fold + 1) * fold_size
1114            };
1115
1116            // Create training set (excluding current fold)
1117            let mut x_train = Vec::new();
1118            let mut y_train = Vec::new();
1119
1120            for &idx in &indices[..start_idx] {
1121                x_train.push(x[idx]);
1122                y_train.push(y[idx]);
1123            }
1124            for &idx in &indices[end_idx..] {
1125                x_train.push(x[idx]);
1126                y_train.push(y[idx]);
1127            }
1128
1129            let x_train_array = Array1::from_vec(x_train);
1130            let y_train_array = Array1::from_vec(y_train);
1131
1132            // Train and predict
1133            let pred = interpolator_factory(&x_train_array.view(), &y_train_array.view(), xnew)?;
1134            predictions.push(pred);
1135        }
1136
1137        // Compute statistics across folds
1138        let mut mean = Array1::zeros(m);
1139        let mut variance = Array1::zeros(m);
1140
1141        for j in 0..m {
1142            let values: Vec<T> = predictions.iter().map(|pred| pred[j]).collect();
1143            let sum: T = values.iter().copied().sum();
1144            mean[j] = sum / T::from(self.k_folds).unwrap();
1145
1146            let var_sum: T = values
1147                .iter()
1148                .map(|&val| (val - mean[j]) * (val - mean[j]))
1149                .sum();
1150            variance[j] = var_sum / T::from(self.k_folds - 1).unwrap();
1151        }
1152
1153        Ok((mean, variance))
1154    }
1155}
1156
1157/// Create an ensemble interpolator with linear and cubic methods
1158#[allow(dead_code)]
1159pub fn make_ensemble_interpolator<
1160    T: Float
1161        + FromPrimitive
1162        + Debug
1163        + Display
1164        + Copy
1165        + std::iter::Sum
1166        + crate::traits::InterpolationFloat,
1167>() -> EnsembleInterpolator<T> {
1168    EnsembleInterpolator::new()
1169        .add_linear_method(T::from(0.6).unwrap())
1170        .add_cubic_method(T::from(0.4).unwrap())
1171}
1172
1173/// Create a cross-validation uncertainty estimator with leave-one-out
1174#[allow(dead_code)]
1175pub fn make_loocv_uncertainty() -> CrossValidationUncertainty {
1176    CrossValidationUncertainty::new(0) // 0 means leave-one-out
1177}
1178
1179/// Create a cross-validation uncertainty estimator with k-fold CV
1180#[allow(dead_code)]
1181pub fn make_kfold_uncertainty(k: usize) -> CrossValidationUncertainty {
1182    CrossValidationUncertainty::new(k)
1183}
1184
/// Isotonic (monotonic) regression interpolator
///
/// Performs interpolation while maintaining monotonicity constraints.
/// This is useful for dose-response relationships and other applications
/// where the underlying relationship must be monotonic.
///
/// Fitting produces one monotone value per training point, with the points
/// sorted by `x`. Queries are answered by piecewise-linear interpolation
/// between those fitted values. Outside the training range, the value at the
/// nearest end is held constant.
#[derive(Debug, Clone)]
pub struct IsotonicInterpolator<T: Float> {
    /// Fitted isotonic values at training points (same order as `x_data`)
    fitted_values: Array1<T>,
    /// Training x coordinates (sorted ascending)
    x_data: Array1<T>,
    /// Direction of the monotonicity constraint used when fitting:
    /// increasing (true) or decreasing (false)
    increasing: bool,
}
1199
1200impl<T: Float + FromPrimitive + Debug + Display + Copy + std::iter::Sum> IsotonicInterpolator<T> {
1201    /// Create a new isotonic interpolator
1202    pub fn new(x: &ArrayView1<T>, y: &ArrayView1<T>, increasing: bool) -> InterpolateResult<Self> {
1203        if x.len() != y.len() {
1204            return Err(InterpolateError::DimensionMismatch(
1205                "x and y must have the same length".to_string(),
1206            ));
1207        }
1208
1209        if x.len() < 2 {
1210            return Err(InterpolateError::invalid_input(
1211                "Need at least 2 points for isotonic regression".to_string(),
1212            ));
1213        }
1214
1215        // Sort by x values
1216        let mut indices: Vec<usize> = (0..x.len()).collect();
1217        indices.sort_by(|&i, &j| x[i].partial_cmp(&x[j]).unwrap());
1218
1219        let x_sorted: Array1<T> = indices.iter().map(|&i| x[i]).collect();
1220        let y_sorted: Array1<T> = indices.iter().map(|&i| y[i]).collect();
1221
1222        // Apply pool-adjacent-violators algorithm
1223        let fitted_values = Self::pool_adjacent_violators(&y_sorted.view(), increasing)?;
1224
1225        Ok(Self {
1226            fitted_values,
1227            x_data: x_sorted,
1228            increasing,
1229        })
1230    }
1231
1232    /// Pool-adjacent-violators algorithm for isotonic regression
1233    fn pool_adjacent_violators(
1234        y: &ArrayView1<T>,
1235        increasing: bool,
1236    ) -> InterpolateResult<Array1<T>> {
1237        let n = y.len();
1238        let mut fitted = y.to_owned();
1239        let mut weights = Array1::<T>::ones(n);
1240
1241        loop {
1242            let mut changed = false;
1243
1244            for i in 0..n - 1 {
1245                let violates = if increasing {
1246                    fitted[i] > fitted[i + 1]
1247                } else {
1248                    fitted[i] < fitted[i + 1]
1249                };
1250
1251                if violates {
1252                    // Pool adjacent blocks
1253                    let total_weight = weights[i] + weights[i + 1];
1254                    let weighted_mean =
1255                        (fitted[i] * weights[i] + fitted[i + 1] * weights[i + 1]) / total_weight;
1256
1257                    fitted[i] = weighted_mean;
1258                    fitted[i + 1] = weighted_mean;
1259                    weights[i] = total_weight;
1260                    weights[i + 1] = total_weight;
1261
1262                    changed = true;
1263                }
1264            }
1265
1266            if !changed {
1267                break;
1268            }
1269        }
1270
1271        Ok(fitted)
1272    }
1273
1274    /// Interpolate at new points
1275    pub fn interpolate(&self, xnew: &ArrayView1<T>) -> InterpolateResult<Array1<T>> {
1276        let mut result = Array1::zeros(xnew.len());
1277
1278        for (i, &x) in xnew.iter().enumerate() {
1279            // Find position in sorted data
1280            let idx = match self
1281                .x_data
1282                .as_slice()
1283                .unwrap()
1284                .binary_search_by(|&probe| probe.partial_cmp(&x).unwrap())
1285            {
1286                Ok(exact_idx) => {
1287                    result[i] = self.fitted_values[exact_idx];
1288                    continue;
1289                }
1290                Err(insert_idx) => insert_idx,
1291            };
1292
1293            // Linear interpolation between adjacent fitted values
1294            if idx == 0 {
1295                result[i] = self.fitted_values[0];
1296            } else if idx >= self.x_data.len() {
1297                result[i] = self.fitted_values[self.x_data.len() - 1];
1298            } else {
1299                let x0 = self.x_data[idx - 1];
1300                let x1 = self.x_data[idx];
1301                let y0 = self.fitted_values[idx - 1];
1302                let y1 = self.fitted_values[idx];
1303
1304                let t = (x - x0) / (x1 - x0);
1305                result[i] = y0 + t * (y1 - y0);
1306            }
1307        }
1308
1309        Ok(result)
1310    }
1311}
1312
/// Kernel Density Estimation (KDE) based interpolator
///
/// Uses kernel density estimation to create smooth interpolations
/// based on probability density functions.
///
/// Concretely, each query value is a kernel-weighted average of the training
/// values (Nadaraya–Watson form):
/// `Σ K((x - x_i) / h) · y_i / Σ K((x - x_i) / h)`, where `h` is the
/// bandwidth and `K` the selected kernel.
#[derive(Debug, Clone)]
pub struct KDEInterpolator<T: Float> {
    /// Training data points
    x_data: Array1<T>,
    /// Training values, one per entry of `x_data`
    y_data: Array1<T>,
    /// Kernel bandwidth `h` (distances are scaled by `1 / h`)
    bandwidth: T,
    /// Kernel type
    kernel_type: KDEKernel,
}
1327
/// Kernel types for KDE interpolation
///
/// Each kernel is evaluated at the scaled distance `u = (x - x_i) / bandwidth`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum KDEKernel {
    /// Gaussian (normal) kernel: `exp(-u²/2) / sqrt(2π)`, unbounded support
    Gaussian,
    /// Epanechnikov kernel (more efficient): `0.75 · (1 - u²)` for `|u| <= 1`, else 0
    Epanechnikov,
    /// Triangular kernel: `1 - |u|` for `|u| <= 1`, else 0
    Triangular,
    /// Uniform (box) kernel: `0.5` for `|u| <= 1`, else 0
    Uniform,
}
1340
1341impl<T: Float + FromPrimitive + Debug + Display + Copy> KDEInterpolator<T> {
1342    /// Create a new KDE interpolator
1343    pub fn new(
1344        x: &ArrayView1<T>,
1345        y: &ArrayView1<T>,
1346        bandwidth: T,
1347        kernel_type: KDEKernel,
1348    ) -> InterpolateResult<Self> {
1349        if x.len() != y.len() {
1350            return Err(InterpolateError::DimensionMismatch(
1351                "x and y must have the same length".to_string(),
1352            ));
1353        }
1354
1355        if bandwidth <= T::zero() {
1356            return Err(InterpolateError::invalid_input(
1357                "Bandwidth must be positive".to_string(),
1358            ));
1359        }
1360
1361        Ok(Self {
1362            x_data: x.to_owned(),
1363            y_data: y.to_owned(),
1364            bandwidth,
1365            kernel_type,
1366        })
1367    }
1368
1369    /// Kernel function evaluation
1370    fn kernel(&self, u: T) -> T {
1371        match self.kernel_type {
1372            KDEKernel::Gaussian => {
1373                let pi = T::from(std::f64::consts::PI).unwrap();
1374                let two = T::from(2.0).unwrap();
1375                let exp_arg = -u * u / two;
1376                exp_arg.exp() / (two * pi).sqrt()
1377            }
1378            KDEKernel::Epanechnikov => {
1379                if u.abs() <= T::one() {
1380                    let three_fourths = T::from(0.75).unwrap();
1381                    three_fourths * (T::one() - u * u)
1382                } else {
1383                    T::zero()
1384                }
1385            }
1386            KDEKernel::Triangular => {
1387                if u.abs() <= T::one() {
1388                    T::one() - u.abs()
1389                } else {
1390                    T::zero()
1391                }
1392            }
1393            KDEKernel::Uniform => {
1394                if u.abs() <= T::one() {
1395                    T::from(0.5).unwrap()
1396                } else {
1397                    T::zero()
1398                }
1399            }
1400        }
1401    }
1402
1403    /// Interpolate at new points using KDE
1404    pub fn interpolate(&self, xnew: &ArrayView1<T>) -> InterpolateResult<Array1<T>> {
1405        let mut result = Array1::zeros(xnew.len());
1406
1407        for (i, &x) in xnew.iter().enumerate() {
1408            let mut weighted_sum = T::zero();
1409            let mut weight_sum = T::zero();
1410
1411            for j in 0..self.x_data.len() {
1412                let u = (x - self.x_data[j]) / self.bandwidth;
1413                let kernel_weight = self.kernel(u);
1414
1415                weighted_sum = weighted_sum + kernel_weight * self.y_data[j];
1416                weight_sum = weight_sum + kernel_weight;
1417            }
1418
1419            if weight_sum > T::zero() {
1420                result[i] = weighted_sum / weight_sum;
1421            } else {
1422                // Fallback to nearest neighbor
1423                let mut min_dist = T::infinity();
1424                let mut nearest_y = self.y_data[0];
1425
1426                for j in 0..self.x_data.len() {
1427                    let dist = (x - self.x_data[j]).abs();
1428                    if dist < min_dist {
1429                        min_dist = dist;
1430                        nearest_y = self.y_data[j];
1431                    }
1432                }
1433
1434                result[i] = nearest_y;
1435            }
1436        }
1437
1438        Ok(result)
1439    }
1440}
1441
/// Empirical Bayes interpolator
///
/// Uses empirical Bayes methods for shrinkage-based interpolation.
/// Particularly useful when dealing with multiple related functions
/// or when prior information is available.
///
/// Each query value is a local nearest-neighbour average of the training
/// values. That average is blended with a scalar prior mean as
/// `(1 - shrinkage_factor) · local + shrinkage_factor · prior_mean`.
#[derive(Debug, Clone)]
pub struct EmpiricalBayesInterpolator<T: Float> {
    /// Training data
    x_data: Array1<T>,
    /// Training values, one per entry of `x_data`
    y_data: Array1<T>,
    /// Shrinkage weight `s` given to `prior_mean` in the blend above
    shrinkage_factor: T,
    /// Prior mean (a single scalar) that estimates are shrunk towards
    prior_mean: T,
    /// Noise variance estimate computed by the constructors
    noise_variance: T,
}
1459
1460impl<T: Float + FromPrimitive + Debug + Display + Copy + std::iter::Sum>
1461    EmpiricalBayesInterpolator<T>
1462{
1463    /// Create a new empirical Bayes interpolator
1464    pub fn new(x: &ArrayView1<T>, y: &ArrayView1<T>) -> InterpolateResult<Self> {
1465        if x.len() != y.len() {
1466            return Err(InterpolateError::DimensionMismatch(
1467                "x and y must have the same length".to_string(),
1468            ));
1469        }
1470
1471        if x.len() < 3 {
1472            return Err(InterpolateError::invalid_input(
1473                "Need at least 3 points for empirical Bayes".to_string(),
1474            ));
1475        }
1476
1477        // Estimate prior mean (overall mean)
1478        let prior_mean = y.iter().copied().sum::<T>() / T::from(y.len()).unwrap();
1479
1480        // Estimate noise variance using residuals
1481        let residuals: Array1<T> = y.iter().map(|&yi| yi - prior_mean).collect();
1482        let noise_variance =
1483            residuals.iter().map(|&r| r * r).sum::<T>() / T::from(residuals.len() - 1).unwrap();
1484
1485        // Compute shrinkage factor using James-Stein type estimator
1486        let signal_variance = noise_variance.max(T::from(1e-10).unwrap());
1487        let shrinkage_factor = noise_variance / (noise_variance + signal_variance);
1488
1489        Ok(Self {
1490            x_data: x.to_owned(),
1491            y_data: y.to_owned(),
1492            shrinkage_factor,
1493            prior_mean,
1494            noise_variance,
1495        })
1496    }
1497
1498    /// Create empirical Bayes interpolator with custom prior
1499    pub fn with_prior(
1500        x: &ArrayView1<T>,
1501        y: &ArrayView1<T>,
1502        prior_mean: T,
1503        shrinkage_factor: T,
1504    ) -> InterpolateResult<Self> {
1505        if x.len() != y.len() {
1506            return Err(InterpolateError::DimensionMismatch(
1507                "x and y must have the same length".to_string(),
1508            ));
1509        }
1510
1511        let residuals: Array1<T> = y.iter().map(|&yi| yi - prior_mean).collect();
1512        let noise_variance =
1513            residuals.iter().map(|&r| r * r).sum::<T>() / T::from(residuals.len().max(1)).unwrap();
1514
1515        Ok(Self {
1516            x_data: x.to_owned(),
1517            y_data: y.to_owned(),
1518            shrinkage_factor,
1519            prior_mean,
1520            noise_variance,
1521        })
1522    }
1523
1524    /// Interpolate using empirical Bayes shrinkage
1525    pub fn interpolate(&self, xnew: &ArrayView1<T>) -> InterpolateResult<Array1<T>> {
1526        let mut result = Array1::zeros(xnew.len());
1527
1528        for (i, &x) in xnew.iter().enumerate() {
1529            // Find nearest neighbors for local estimation
1530            let mut distances: Vec<(T, usize)> = self
1531                .x_data
1532                .iter()
1533                .enumerate()
1534                .map(|(j, &xi)| ((x - xi).abs(), j))
1535                .collect();
1536            distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
1537
1538            // Use k nearest neighbors (k = 3 or n/2, whichever is smaller)
1539            let k = (3_usize).min(self.x_data.len() / 2).max(1);
1540            let mut local_mean = T::zero();
1541            let mut total_weight = T::zero();
1542
1543            for &(dist, j) in distances.iter().take(k) {
1544                let weight = if dist == T::zero() {
1545                    T::one()
1546                } else {
1547                    T::one() / (T::one() + dist)
1548                };
1549                local_mean = local_mean + weight * self.y_data[j];
1550                total_weight = total_weight + weight;
1551            }
1552
1553            if total_weight > T::zero() {
1554                local_mean = local_mean / total_weight;
1555            } else {
1556                local_mean = self.prior_mean;
1557            }
1558
1559            // Apply empirical Bayes shrinkage
1560            let shrunk_estimate = (T::one() - self.shrinkage_factor) * local_mean
1561                + self.shrinkage_factor * self.prior_mean;
1562
1563            result[i] = shrunk_estimate;
1564        }
1565
1566        Ok(result)
1567    }
1568
1569    /// Get shrinkage factor
1570    pub fn get_shrinkage_factor(&self) -> T {
1571        self.shrinkage_factor
1572    }
1573
1574    /// Get prior mean
1575    pub fn get_prior_mean(&self) -> T {
1576        self.prior_mean
1577    }
1578
1579    /// Get noise variance estimate
1580    pub fn get_noise_variance(&self) -> T {
1581        self.noise_variance
1582    }
1583}
1584
1585/// Convenience function to create an isotonic interpolator (increasing)
1586#[allow(dead_code)]
1587pub fn make_isotonic_interpolator<
1588    T: Float + FromPrimitive + Debug + Display + Copy + std::iter::Sum,
1589>(
1590    x: &ArrayView1<T>,
1591    y: &ArrayView1<T>,
1592) -> InterpolateResult<IsotonicInterpolator<T>> {
1593    IsotonicInterpolator::new(x, y, true)
1594}
1595
1596/// Convenience function to create a decreasing isotonic interpolator
1597#[allow(dead_code)]
1598pub fn make_decreasing_isotonic_interpolator<
1599    T: Float + FromPrimitive + Debug + Display + Copy + std::iter::Sum,
1600>(
1601    x: &ArrayView1<T>,
1602    y: &ArrayView1<T>,
1603) -> InterpolateResult<IsotonicInterpolator<T>> {
1604    IsotonicInterpolator::new(x, y, false)
1605}
1606
1607/// Convenience function to create a KDE interpolator with Gaussian kernel
1608#[allow(dead_code)]
1609pub fn make_kde_interpolator<T: crate::traits::InterpolationFloat + Copy>(
1610    x: &ArrayView1<T>,
1611    y: &ArrayView1<T>,
1612    bandwidth: T,
1613) -> InterpolateResult<KDEInterpolator<T>> {
1614    KDEInterpolator::new(x, y, bandwidth, KDEKernel::Gaussian)
1615}
1616
1617/// Convenience function to create a KDE interpolator with automatic bandwidth selection
1618#[allow(dead_code)]
1619pub fn make_auto_kde_interpolator<
1620    T: Float + FromPrimitive + Debug + Display + Copy + std::iter::Sum,
1621>(
1622    x: &ArrayView1<T>,
1623    y: &ArrayView1<T>,
1624) -> InterpolateResult<KDEInterpolator<T>> {
1625    // Scott's rule for bandwidth selection
1626    let n = T::from(x.len()).unwrap();
1627    let x_std = {
1628        let mean = x.iter().copied().sum::<T>() / n;
1629        let variance = x.iter().map(|&xi| (xi - mean) * (xi - mean)).sum::<T>() / (n - T::one());
1630        variance.sqrt()
1631    };
1632
1633    let bandwidth = x_std * n.powf(-T::from(0.2).unwrap()); // n^(-1/5)
1634    KDEInterpolator::new(x, y, bandwidth, KDEKernel::Gaussian)
1635}
1636
1637/// Convenience function to create an empirical Bayes interpolator
1638#[allow(dead_code)]
1639pub fn make_empirical_bayes_interpolator<
1640    T: Float + FromPrimitive + Debug + Display + Copy + std::iter::Sum,
1641>(
1642    x: &ArrayView1<T>,
1643    y: &ArrayView1<T>,
1644) -> InterpolateResult<EmpiricalBayesInterpolator<T>> {
1645    EmpiricalBayesInterpolator::new(x, y)
1646}