use crate::error::InterpolateError;
/// Summary statistics describing a data set.
///
/// Produced by [`analyze_data`] and consumed by [`recommend_method`] and
/// [`recommend_with_rationale`]. `DataProfile::default()` yields zero points,
/// zero dimensions, a smoothness estimate of `0.0` and both flags `false`.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct DataProfile {
    /// Number of sample points.
    pub n_points: usize,
    /// Number of coordinates per sample point.
    pub n_dims: usize,
    /// Relative roughness of the response values; larger means rougher.
    pub smoothness_estimate: f64,
    /// Whether the response values look noisy.
    pub has_noise: bool,
    /// Whether the response values look periodic.
    pub is_periodic: bool,
}
/// Interpolation strategies that can be recommended for a data set.
///
/// Derives `Eq` and `Hash` so methods can be used as map or set keys.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum InterpolationMethod {
    /// Linear spline. Not currently returned by `recommend_method`.
    LinearSpline,
    /// Cubic spline; recommended for noise-free 1-D data.
    CubicSpline,
    /// Radial basis function (RBF) interpolation of scattered data.
    RadialBasis,
    /// Interpolation on a tensor-product grid.
    TensorProduct,
    /// Smolyak sparse-grid interpolation.
    SparseGrid,
    /// Tensor-train (TT-SVD / TT-cross) approximation for high dimensions.
    TensorTrain,
}
/// Formats a method as its variant name, e.g. `CubicSpline`.
///
/// Uses [`Formatter::pad`](std::fmt::Formatter::pad) so that width, fill,
/// alignment and precision flags are honoured (e.g. `format!("{:>14}", m)`),
/// matching how `str` itself is displayed.
impl std::fmt::Display for InterpolationMethod {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            InterpolationMethod::LinearSpline => "LinearSpline",
            InterpolationMethod::CubicSpline => "CubicSpline",
            InterpolationMethod::RadialBasis => "RadialBasis",
            InterpolationMethod::TensorProduct => "TensorProduct",
            InterpolationMethod::SparseGrid => "SparseGrid",
            InterpolationMethod::TensorTrain => "TensorTrain",
        };
        f.pad(name)
    }
}
/// Computes a lightweight statistical profile of a scattered data set.
///
/// `x[i]` is the `i`-th sample point and `y[i]` its response value. Samples
/// are ordered by their first coordinate, and `y` is examined along that
/// ordering:
///
/// * `smoothness_estimate`: RMS of the discrete second difference of the
///   ordered `y`, divided by the RMS of `y` (floored at `1e-12`).
/// * `has_noise`: `smoothness_estimate > 0.3`.
/// * `is_periodic`: the first and last ordered values differ by less than
///   5 % of `max |y|` (floored at `1e-12`).
///
/// Only the first coordinate drives the ordering, so for multi-dimensional
/// data these are heuristics along that axis. `n_dims` is taken from the
/// first point.
///
/// With fewer than three paired samples, or zero-dimensional points, the
/// returned profile has `smoothness_estimate = 0.0` and both flags `false`.
///
/// Malformed input does not cause a panic:
/// * If `x` and `y` differ in length, only the first
///   `min(x.len(), y.len())` samples are analysed. `n_points` still reports
///   `x.len()`.
/// * A point with no coordinates is ordered as if its first coordinate
///   were NaN.
pub fn analyze_data(x: &[Vec<f64>], y: &[f64]) -> DataProfile {
    let n_points = x.len();
    let n_dims = x.first().map_or(0, Vec::len);
    // Only indices present in both slices can be analysed; indexing past the
    // shorter slice would panic.
    let n_pairs = n_points.min(y.len());
    if n_pairs < 3 || n_dims == 0 {
        return DataProfile {
            n_points,
            n_dims,
            smoothness_estimate: 0.0,
            has_noise: false,
            is_periodic: false,
        };
    }

    // Order samples by first coordinate. `total_cmp` is a genuine total order
    // (NaN included), which `sort_by` requires; the stable sort keeps ties in
    // input order. Empty rows sort as NaN instead of panicking on `[0]`.
    let first_coord = |i: usize| x[i].first().copied().unwrap_or(f64::NAN);
    let mut order: Vec<usize> = (0..n_pairs).collect();
    order.sort_by(|&a, &b| first_coord(a).total_cmp(&first_coord(b)));
    let y_sorted: Vec<f64> = order.iter().map(|&i| y[i]).collect();

    // Floor avoids division by zero for all-zero responses.
    let rms_y = (y_sorted.iter().map(|&v| v * v).sum::<f64>() / n_pairs as f64)
        .sqrt()
        .max(1e-12);

    // n_pairs >= 3 here, so there is at least one interior point and
    // `n_pairs - 2` is non-zero.
    let second_diff_ss: f64 = y_sorted
        .windows(3)
        .map(|w| {
            let d2 = w[2] - 2.0 * w[1] + w[0];
            d2 * d2
        })
        .sum();
    let second_diff_rms = (second_diff_ss / (n_pairs - 2) as f64).sqrt();

    let smoothness_estimate = second_diff_rms / rms_y;
    let has_noise = smoothness_estimate > 0.3;

    let y_max_abs = y_sorted
        .iter()
        .map(|v| v.abs())
        .fold(0.0_f64, f64::max)
        .max(1e-12);
    let endpoint_diff = (y_sorted[0] - y_sorted[n_pairs - 1]).abs();
    let is_periodic = endpoint_diff / y_max_abs < 0.05;

    DataProfile {
        n_points,
        n_dims,
        smoothness_estimate,
        has_noise,
        is_periodic,
    }
}
/// Picks an interpolation method for a data set from its [`DataProfile`].
///
/// Rules are checked top to bottom and the first match wins:
///
/// 1. 1-D data without noise → `CubicSpline`
/// 2. more than 10 dimensions → `TensorTrain`
/// 3. at most 4 dimensions and fewer than 500 points → `RadialBasis`
/// 4. at most 6 dimensions and fewer than 10 000 points → `TensorProduct`
/// 5. more than 6 dimensions and more than 1 000 points → `SparseGrid`
/// 6. anything else → `RadialBasis`
///
/// Only `n_dims`, `n_points` and `has_noise` are consulted.
pub fn recommend_method(profile: &DataProfile) -> InterpolationMethod {
    match (profile.n_dims, profile.n_points) {
        (1, _) if !profile.has_noise => InterpolationMethod::CubicSpline,
        (dims, _) if dims > 10 => InterpolationMethod::TensorTrain,
        (dims, points) if dims <= 4 && points < 500 => InterpolationMethod::RadialBasis,
        (dims, points) if dims <= 6 && points < 10_000 => InterpolationMethod::TensorProduct,
        (dims, points) if dims > 6 && points > 1_000 => InterpolationMethod::SparseGrid,
        _ => InterpolationMethod::RadialBasis,
    }
}
/// Like [`recommend_method`], but also returns a one-sentence explanation of
/// why the method was chosen.
///
/// The rules and their order are the same as `recommend_method`'s. The
/// explanation quotes the profile's dimension count and point count.
pub fn recommend_with_rationale(profile: &DataProfile) -> (InterpolationMethod, String) {
    match (profile.n_dims, profile.n_points) {
        (1, n) if !profile.has_noise => (
            InterpolationMethod::CubicSpline,
            format!(
                "1-D data ({n} points) without noise: CubicSpline gives smooth, \
                 C² interpolation at O(n) cost."
            ),
        ),
        (d, n) if d > 10 => (
            InterpolationMethod::TensorTrain,
            format!(
                "{d}-D data ({n} points): dimensionality exceeds 10; \
                 TensorTrain (TT-SVD/TT-cross) avoids the curse of dimensionality."
            ),
        ),
        (d, n) if d <= 4 && n < 500 => (
            InterpolationMethod::RadialBasis,
            format!(
                "{d}-D scattered data ({n} points): RBF provides flexible \
                 interpolation without a grid structure."
            ),
        ),
        (d, n) if d <= 6 && n < 10_000 => (
            InterpolationMethod::TensorProduct,
            format!(
                "{d}-D data ({n} points): a tensor-product grid is feasible \
                 and gives fast O(n) evaluation per dimension."
            ),
        ),
        (d, n) if d > 6 && n > 1_000 => (
            InterpolationMethod::SparseGrid,
            format!(
                "{d}-D data ({n} points): Smolyak sparse grid reduces the \
                 exponential cost of tensor-product methods in moderate dimensions."
            ),
        ),
        (d, n) => (
            InterpolationMethod::RadialBasis,
            format!(
                "Default choice for {d}-D data ({n} points): RBF interpolation \
                 works well for general scattered data."
            ),
        ),
    }
}
/// Checks that `x` and `y` describe a consistent data set.
///
/// # Errors
///
/// Returns `InterpolateError::DimensionMismatch` when
/// * `x` and `y` have different lengths, or
/// * the points in `x` do not all have the same number of coordinates
///   as the first point.
///
/// Empty input is accepted.
// NOTE(review): no caller is visible in this file; `dead_code` is allowed so
// the crate builds warning-free. Confirm whether it is used elsewhere.
#[allow(dead_code)]
pub(crate) fn validate_input(x: &[Vec<f64>], y: &[f64]) -> Result<(), InterpolateError> {
    if x.len() != y.len() {
        return Err(InterpolateError::DimensionMismatch(format!(
            "x has {} points but y has {} values",
            x.len(),
            y.len()
        )));
    }
    // Ragged points would make any multi-dimensional interpolant ill-defined.
    if let Some(first) = x.first() {
        let n_dims = first.len();
        if let Some((i, row)) = x
            .iter()
            .enumerate()
            .find(|(_, row)| row.len() != n_dims)
        {
            return Err(InterpolateError::DimensionMismatch(format!(
                "point {i} has {} coordinates but point 0 has {n_dims}",
                row.len()
            )));
        }
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// 50-ish samples of `y = x²` on `[0, 1)`.
    fn make_1d_data(n: usize) -> (Vec<Vec<f64>>, Vec<f64>) {
        let x: Vec<Vec<f64>> = (0..n).map(|i| vec![i as f64 / n as f64]).collect();
        let y: Vec<f64> = x.iter().map(|p| p[0] * p[0]).collect();
        (x, y)
    }

    /// `n` points on the diagonal of `[0, 1)^d` with `y = sum of coordinates`.
    fn make_nd_data(n: usize, d: usize) -> (Vec<Vec<f64>>, Vec<f64>) {
        let x: Vec<Vec<f64>> = (0..n).map(|i| vec![i as f64 / n as f64; d]).collect();
        let y: Vec<f64> = x.iter().map(|p| p.iter().sum::<f64>()).collect();
        (x, y)
    }

    #[test]
    fn test_1d_smooth_data_recommends_cubic_spline() {
        let (x, y) = make_1d_data(50);
        let profile = analyze_data(&x, &y);
        assert_eq!(profile.n_dims, 1);
        let method = recommend_method(&profile);
        assert_eq!(method, InterpolationMethod::CubicSpline);
    }

    #[test]
    fn test_high_dim_recommends_tensor_train() {
        let (x, y) = make_nd_data(2000, 15);
        let profile = analyze_data(&x, &y);
        let method = recommend_method(&profile);
        assert_eq!(method, InterpolationMethod::TensorTrain);
    }

    #[test]
    fn test_moderate_dim_recommends_sparse_grid() {
        let (x, y) = make_nd_data(2000, 8);
        let profile = analyze_data(&x, &y);
        let method = recommend_method(&profile);
        assert_eq!(method, InterpolationMethod::SparseGrid);
    }

    #[test]
    fn test_small_4d_recommends_rbf() {
        let (x, y) = make_nd_data(100, 4);
        let profile = analyze_data(&x, &y);
        let method = recommend_method(&profile);
        assert_eq!(method, InterpolationMethod::RadialBasis);
    }

    #[test]
    fn test_recommend_with_rationale_returns_string() {
        let (x, y) = make_1d_data(20);
        let profile = analyze_data(&x, &y);
        let (method, reason) = recommend_with_rationale(&profile);
        assert_eq!(method, InterpolationMethod::CubicSpline);
        assert!(!reason.is_empty(), "rationale string should not be empty");
    }

    #[test]
    fn test_analyze_data_smoothness_for_noisy_data() {
        let x: Vec<Vec<f64>> = (0..20).map(|i| vec![i as f64 * 0.1]).collect();
        let y: Vec<f64> = (0..20)
            .map(|i| if i % 2 == 0 { 0.0 } else { 1.0 })
            .collect();
        let profile = analyze_data(&x, &y);
        assert!(
            profile.has_noise,
            "alternating data should be flagged as noisy"
        );
    }

    #[test]
    fn test_periodicity_detected() {
        use std::f64::consts::PI;
        let n = 65_usize;
        let x: Vec<Vec<f64>> = (0..n)
            .map(|i| vec![i as f64 * 2.0 * PI / (n - 1) as f64])
            .collect();
        let y: Vec<f64> = x.iter().map(|p| p[0].sin()).collect();
        let profile = analyze_data(&x, &y);
        assert!(
            profile.is_periodic,
            "sin data on [0,2π] should be detected as periodic; y[0]={:.4}, y[last]={:.4}",
            y[0],
            y[n - 1]
        );
    }

    #[test]
    fn test_empty_data_no_panic() {
        let profile = analyze_data(&[], &[]);
        assert_eq!(profile.n_points, 0);
        assert_eq!(profile.n_dims, 0);
    }

    /// `recommend_with_rationale` re-implements the decision tree of
    /// `recommend_method`; this guards against the two drifting apart.
    #[test]
    fn test_rationale_agrees_with_recommend_method() {
        for &n_dims in &[0_usize, 1, 2, 4, 5, 6, 7, 10, 11, 20] {
            for &n_points in &[0_usize, 3, 499, 500, 1_000, 1_001, 9_999, 10_000, 50_000] {
                for &has_noise in &[false, true] {
                    let profile = DataProfile {
                        n_points,
                        n_dims,
                        smoothness_estimate: 0.0,
                        has_noise,
                        is_periodic: false,
                    };
                    let (method, reason) = recommend_with_rationale(&profile);
                    assert_eq!(
                        method,
                        recommend_method(&profile),
                        "disagreement for {profile:?}"
                    );
                    assert!(!reason.is_empty(), "empty rationale for {profile:?}");
                }
            }
        }
    }

    #[test]
    fn test_display_matches_variant_name() {
        let all = [
            InterpolationMethod::LinearSpline,
            InterpolationMethod::CubicSpline,
            InterpolationMethod::RadialBasis,
            InterpolationMethod::TensorProduct,
            InterpolationMethod::SparseGrid,
            InterpolationMethod::TensorTrain,
        ];
        for method in &all {
            assert_eq!(method.to_string(), format!("{method:?}"));
        }
    }
}