use crate::error::{InterpolateError, InterpolateResult};
use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt::{Debug, Display};
/// Output of [`KrigingInterpolator::predict`], holding one entry per query point.
#[derive(Clone, Debug)]
pub struct PredictionResult<F: Float + FromPrimitive + Display> {
    /// Predicted value at each query point, in query order.
    pub value: Array1<F>,
    /// Prediction variance at each query point, in query order (clamped at zero).
    pub variance: Array1<F>,
}
/// Stationary covariance models available to the Kriging interpolator.
///
/// With `r` the Euclidean distance between two points, `l` the length scale,
/// `s2` the signal variance (`sigma_sq`) and `a` the shape parameter (`alpha`):
///
/// * `SquaredExponential`: `s2 * exp(-(r/l)^2)` (note: no factor of 1/2 in the exponent)
/// * `Exponential`: `s2 * exp(-r/l)`
/// * `Matern32`: `s2 * (1 + sqrt(3) r/l) * exp(-sqrt(3) r/l)`
/// * `Matern52`: `s2 * (1 + sqrt(5) r/l + 5 r^2 / (3 l^2)) * exp(-sqrt(5) r/l)`
/// * `RationalQuadratic`: `s2 * (1 + r^2 / (2 a l^2))^(-a)`
///
/// The enum is fieldless, so it is `Copy` and can be compared, hashed and used as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CovarianceFunction {
    /// Gaussian kernel; yields very smooth (infinitely differentiable) fields.
    SquaredExponential,
    /// Exponential kernel (Matérn with nu = 1/2); yields rough, continuous fields.
    Exponential,
    /// Matérn kernel with nu = 3/2.
    Matern32,
    /// Matérn kernel with nu = 5/2.
    Matern52,
    /// Rational quadratic kernel; requires a positive `alpha`.
    RationalQuadratic,
}
/// Ordinary Kriging interpolator for scattered data in any number of dimensions.
///
/// The field is modelled as an unknown constant mean plus a zero-mean process with the
/// chosen [`CovarianceFunction`]. Construction assembles and LU-factorises the
/// `(n + 1) x (n + 1)` ordinary-Kriging system
///
/// ```text
/// [ K    1 ] [ w ]   [ y ]
/// [ 1^T  0 ] [ m ] = [ 0 ]
/// ```
///
/// where `K[i][j] = C(|x_i - x_j|)` with `nugget` added on the diagonal. `w` are the
/// dual weights and `m` is the generalised-least-squares estimate of the mean. A
/// prediction is therefore `m + k*^T w` with `k*[j] = C(|x* - x_j|)`.
///
/// Construction costs O(n^3); each predicted point costs O(n^2).
#[derive(Debug, Clone)]
// `values` and `cov_matrix` are retained but not read after construction: prediction
// only needs the LU factors and the dual weights, so silence the dead-field warning.
#[allow(dead_code)]
pub struct KrigingInterpolator<F: Float + FromPrimitive + Display> {
    /// Sample locations, one point per row.
    points: Array2<F>,
    /// Observed values at `points`.
    values: Array1<F>,
    /// Covariance model.
    cov_fn: CovarianceFunction,
    /// Signal variance.
    sigma_sq: F,
    /// Correlation length.
    length_scale: F,
    /// Extra variance on the diagonal of `K` (noise / regularisation).
    nugget: F,
    /// Shape parameter; only used by the rational quadratic model.
    alpha: F,
    /// The augmented `(n + 1) x (n + 1)` ordinary-Kriging matrix.
    cov_matrix: Array2<F>,
    /// Packed LU factors of `cov_matrix`: unit-lower `L` below the diagonal, `U` on and above.
    lu: Array2<F>,
    /// Row interchanges of the factorisation: at step `k`, row `k` was swapped with `pivots[k]`.
    pivots: Vec<usize>,
    /// Dual weights `w`.
    weights: Array1<F>,
    /// GLS estimate `m` of the constant mean.
    mean: F,
}

impl<F: Float + FromPrimitive + Debug + std::fmt::Display> KrigingInterpolator<F> {
    /// Builds an ordinary Kriging interpolator from scattered samples.
    ///
    /// # Arguments
    /// * `points` - sample locations, one point per row (`n x d`).
    /// * `values` - observed value at each sample point (length `n`).
    /// * `cov_fn` - covariance model.
    /// * `sigma_sq` - signal variance; must be positive.
    /// * `length_scale` - correlation length; must be positive.
    /// * `nugget` - variance added to the diagonal of the covariance matrix; must be
    ///   non-negative. A positive nugget makes the predictor smooth instead of
    ///   reproducing the data exactly, and regularises near-duplicate points.
    /// * `alpha` - shape parameter; must be positive for
    ///   [`CovarianceFunction::RationalQuadratic`] and is ignored by the other models.
    ///
    /// # Errors
    /// Returns an error in any of these cases:
    /// * the number of points and values differ, or fewer than two points are given;
    /// * a parameter is out of range or NaN;
    /// * a coordinate or value is not finite;
    /// * the Kriging system is singular to working precision (for example, duplicate
    ///   sample points with a zero nugget).
    pub fn new(
        points: &ArrayView2<F>,
        values: &ArrayView1<F>,
        cov_fn: CovarianceFunction,
        sigma_sq: F,
        length_scale: F,
        nugget: F,
        alpha: F,
    ) -> InterpolateResult<Self> {
        if points.shape()[0] != values.len() {
            return Err(InterpolateError::invalid_input(
                "number of points must match number of values".to_string(),
            ));
        }
        if points.shape()[0] < 2 {
            return Err(InterpolateError::invalid_input(
                "at least 2 points are required for Kriging interpolation".to_string(),
            ));
        }
        // Every comparison with NaN is false, so test for NaN explicitly.
        if sigma_sq.is_nan() || sigma_sq <= F::zero() {
            return Err(InterpolateError::invalid_parameter_with_suggestion(
                "sigma_sq",
                sigma_sq,
                "Kriging interpolation",
                "must be positive (signal variance: try sample variance of your data or 1.0 as default)"
            ));
        }
        if length_scale.is_nan() || length_scale <= F::zero() {
            return Err(InterpolateError::invalid_parameter_with_suggestion(
                "length_scale",
                length_scale,
                "Kriging interpolation",
                "must be positive (correlation length: try mean distance between points or use cross-validation)"
            ));
        }
        if nugget.is_nan() || nugget < F::zero() {
            return Err(InterpolateError::invalid_input(
                "nugget must be non-negative".to_string(),
            ));
        }
        if cov_fn == CovarianceFunction::RationalQuadratic
            && (alpha.is_nan() || alpha <= F::zero())
        {
            return Err(InterpolateError::invalid_parameter_with_suggestion(
                "alpha",
                alpha,
                "rational quadratic Kriging",
                "must be positive (shape parameter: typical values 0.5-2.0, try 1.0 as default)",
            ));
        }
        if points.iter().chain(values.iter()).any(|&v| !v.is_finite()) {
            return Err(InterpolateError::invalid_input(
                "points and values must be finite (no NaN or infinity)".to_string(),
            ));
        }

        let n = points.shape()[0];

        // Assemble the symmetric augmented system [K 1; 1^T 0].
        let mut cov_matrix = Array2::zeros((n + 1, n + 1));
        for i in 0..n {
            let p_i = points.row(i);
            cov_matrix[[i, i]] = sigma_sq + nugget;
            for j in (i + 1)..n {
                let dist = Self::distance(&p_i, &points.row(j));
                let cov = Self::covariance(dist, sigma_sq, length_scale, cov_fn, alpha);
                cov_matrix[[i, j]] = cov;
                cov_matrix[[j, i]] = cov;
            }
            // Lagrange border enforcing that the Kriging weights sum to one.
            cov_matrix[[i, n]] = F::one();
            cov_matrix[[n, i]] = F::one();
        }
        // The bottom-right corner stays zero from `Array2::zeros`.

        let (lu, pivots) = Self::lu_factorize(cov_matrix.clone())?;

        // Solve [K 1; 1^T 0] [w; m] = [y; 0] once; predictions then only need k*.
        let mut solution: Array1<F> = Array1::zeros(n + 1);
        for (dst, &v) in solution.iter_mut().zip(values.iter()) {
            *dst = v;
        }
        Self::lu_solve(&lu, &pivots, &mut solution);
        let weights: Array1<F> = solution.iter().take(n).copied().collect();
        let mean = solution[n];

        Ok(Self {
            points: points.to_owned(),
            values: values.to_owned(),
            cov_fn,
            sigma_sq,
            length_scale,
            nugget,
            alpha,
            cov_matrix,
            lu,
            pivots,
            weights,
            mean,
        })
    }

    /// Euclidean distance between two points of equal dimension.
    fn distance(p1: &ArrayView1<F>, p2: &ArrayView1<F>) -> F {
        p1.iter()
            .zip(p2.iter())
            .fold(F::zero(), |acc, (&a, &b)| {
                let diff = a - b;
                acc + diff * diff
            })
            .sqrt()
    }

    /// Converts a small `f64` literal into `F`.
    fn constant(x: f64) -> F {
        F::from_f64(x).expect("small f64 constants are representable in any Float type")
    }

    /// Evaluates covariance model `covfn` at distance `r`.
    ///
    /// Returns exactly `sigma_sq` at `r = 0` for every model. `alpha` is only read by the
    /// rational quadratic model.
    fn covariance(r: F, sigma_sq: F, length_scale: F, covfn: CovarianceFunction, alpha: F) -> F {
        let scaled_dist = r / length_scale;
        match covfn {
            CovarianceFunction::SquaredExponential => {
                sigma_sq * (-scaled_dist * scaled_dist).exp()
            }
            CovarianceFunction::Exponential => sigma_sq * (-scaled_dist).exp(),
            CovarianceFunction::Matern32 => {
                let sqrt3_r_l = Self::constant(3.0).sqrt() * scaled_dist;
                sigma_sq * (F::one() + sqrt3_r_l) * (-sqrt3_r_l).exp()
            }
            CovarianceFunction::Matern52 => {
                let sqrt5_r_l = Self::constant(5.0).sqrt() * scaled_dist;
                let factor = F::one()
                    + sqrt5_r_l
                    + Self::constant(5.0) * scaled_dist * scaled_dist / Self::constant(3.0);
                sigma_sq * factor * (-sqrt5_r_l).exp()
            }
            CovarianceFunction::RationalQuadratic => {
                let r_sq_div_2al_sq =
                    scaled_dist * scaled_dist / (Self::constant(2.0) * alpha);
                sigma_sq * (F::one() + r_sq_div_2al_sq).powf(-alpha)
            }
        }
    }

    /// LU-factorises a square matrix in place using partial (row) pivoting.
    ///
    /// Returns the packed factors: the unit-lower `L` multipliers are stored strictly
    /// below the diagonal and `U` on and above it. Also returns the row-interchange
    /// sequence, where entry `k` is the row swapped with row `k` at step `k`.
    ///
    /// # Errors
    /// Fails when the best available pivot is not larger than `eps * n * max|a_ij|`, i.e.
    /// the matrix is singular to working precision.
    fn lu_factorize(mut a: Array2<F>) -> InterpolateResult<(Array2<F>, Vec<usize>)> {
        let n = a.nrows();
        let scale = a.iter().fold(F::zero(), |m, &x| m.max(x.abs()));
        let tol = F::epsilon() * scale * Self::constant(n as f64);
        let mut pivots = Vec::with_capacity(n);
        for k in 0..n {
            // Pick the largest-magnitude entry on or below the diagonal in column k.
            let (pivot_row, pivot_abs) = (k..n)
                .map(|r| (r, a[[r, k]].abs()))
                .fold((k, F::zero()), |best, cand| if cand.1 > best.1 { cand } else { best });
            if pivot_abs <= tol {
                return Err(InterpolateError::invalid_input(
                    "Kriging system is singular to working precision; check for duplicate \
                     sample points or use a positive nugget"
                        .to_string(),
                ));
            }
            if pivot_row != k {
                for c in 0..n {
                    a.swap([k, c], [pivot_row, c]);
                }
            }
            pivots.push(pivot_row);

            let pivot = a[[k, k]];
            for r in (k + 1)..n {
                let factor = a[[r, k]] / pivot;
                a[[r, k]] = factor;
                for c in (k + 1)..n {
                    let updated = a[[r, c]] - factor * a[[k, c]];
                    a[[r, c]] = updated;
                }
            }
        }
        Ok((a, pivots))
    }

    /// Solves `A x = b` in place, where `lu`/`pivots` come from [`Self::lu_factorize`] on `A`.
    fn lu_solve(lu: &Array2<F>, pivots: &[usize], rhs: &mut Array1<F>) {
        let n = lu.nrows();
        // Replay the row interchanges in the order they were made.
        for (k, &p) in pivots.iter().enumerate() {
            if p != k {
                rhs.swap(k, p);
            }
        }
        // Forward substitution with the unit lower-triangular factor.
        for i in 1..n {
            let s = (0..i).fold(rhs[i], |acc, j| acc - lu[[i, j]] * rhs[j]);
            rhs[i] = s;
        }
        // Back substitution with the upper-triangular factor.
        for i in (0..n).rev() {
            let s = ((i + 1)..n).fold(rhs[i], |acc, j| acc - lu[[i, j]] * rhs[j]);
            rhs[i] = s / lu[[i, i]];
        }
    }

    /// Predicts values and Kriging variances at the given query points.
    ///
    /// `querypoints` holds one query per row and must have as many columns as the sample
    /// points. For each query `x*`:
    /// * the value is `m + k*^T w`, equivalent to `lambda^T y`;
    /// * the variance is `sigma_sq - lambda^T k* - mu`, where `[lambda; mu]` solves the
    ///   Kriging system with right-hand side `[k*; 1]`. It is clamped at zero to absorb
    ///   rounding.
    ///
    /// The query/sample covariance `k*` never includes the nugget.
    ///
    /// # Errors
    /// Returns an error if the query dimension differs from the sample dimension.
    pub fn predict(&self, querypoints: &ArrayView2<F>) -> InterpolateResult<PredictionResult<F>> {
        if querypoints.ncols() != self.points.ncols() {
            return Err(InterpolateError::invalid_input(
                "query points must have the same dimension as sample points".to_string(),
            ));
        }
        let n_query = querypoints.nrows();
        let n = self.points.nrows();
        let mut predictions: Array1<F> = Array1::zeros(n_query);
        let mut variances: Array1<F> = Array1::zeros(n_query);

        // Buffers reused across queries: `k_star` = [k*; 1], `lambda` = its solution.
        let mut k_star: Array1<F> = Array1::zeros(n + 1);
        let mut lambda: Array1<F> = Array1::zeros(n + 1);

        for (q, query) in querypoints.outer_iter().enumerate() {
            for (slot, sample) in k_star.iter_mut().zip(self.points.outer_iter()) {
                *slot = Self::covariance(
                    Self::distance(&query, &sample),
                    self.sigma_sq,
                    self.length_scale,
                    self.cov_fn,
                    self.alpha,
                );
            }
            k_star[n] = F::one();

            // Dual form: `weights` has length n, so the trailing 1 in `k_star` is skipped.
            predictions[q] = k_star
                .iter()
                .zip(self.weights.iter())
                .fold(self.mean, |acc, (&k, &w)| acc + k * w);

            lambda.assign(&k_star);
            Self::lu_solve(&self.lu, &self.pivots, &mut lambda);
            // [k*; 1]^T [lambda; mu] = lambda^T k* + mu.
            let explained = k_star
                .iter()
                .zip(lambda.iter())
                .fold(F::zero(), |acc, (&k, &l)| acc + k * l);
            let variance = self.sigma_sq - explained;
            variances[q] = if variance < F::zero() {
                F::zero()
            } else {
                variance
            };
        }

        Ok(PredictionResult {
            value: predictions,
            variance: variances,
        })
    }

    /// Covariance model used by this interpolator.
    pub fn covariance_function(&self) -> CovarianceFunction {
        self.cov_fn
    }

    /// Signal variance passed at construction.
    pub fn sigma_sq(&self) -> F {
        self.sigma_sq
    }

    /// Correlation length passed at construction.
    pub fn length_scale(&self) -> F {
        self.length_scale
    }

    /// Nugget (diagonal variance) passed at construction.
    pub fn nugget(&self) -> F {
        self.nugget
    }

    /// Shape parameter passed at construction (only used by the rational quadratic model).
    pub fn alpha(&self) -> F {
        self.alpha
    }
}
/// Convenience constructor that forwards all arguments to `KrigingInterpolator::new`.
///
/// # Errors
/// Returns exactly the errors produced by `KrigingInterpolator::new`.
// NOTE(review): presumably kept to silence an unused-function warning when this module is
// not re-exported; confirm whether it is still needed.
#[allow(dead_code)]
pub fn make_kriging_interpolator<F: crate::traits::InterpolationFloat>(
    points: &ArrayView2<F>,
    values: &ArrayView1<F>,
    cov_fn: CovarianceFunction,
    sigma_sq: F,
    length_scale: F,
    nugget: F,
    alpha: F,
) -> InterpolateResult<KrigingInterpolator<F>> {
    KrigingInterpolator::<F>::new(points, values, cov_fn, sigma_sq, length_scale, nugget, alpha)
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Every covariance model, for table-driven checks.
    const ALL_COV_FNS: [CovarianceFunction; 5] = [
        CovarianceFunction::SquaredExponential,
        CovarianceFunction::Exponential,
        CovarianceFunction::Matern32,
        CovarianceFunction::Matern52,
        CovarianceFunction::RationalQuadratic,
    ];

    /// Predicting at the sample points themselves stays close to the data and has
    /// (near) zero variance.
    #[test]
    fn test_kriging_interpolator_exact() {
        // Unit-square corners plus the centre.
        let points = Array2::from_shape_vec(
            (5, 2),
            vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.5, 0.5],
        )
        .expect("5x2 shape matches 10 elements");
        let values = array![0.0, 1.0, 1.0, 2.0, 0.5];
        let interp_se = KrigingInterpolator::new(
            &points.view(),
            &values.view(),
            CovarianceFunction::SquaredExponential,
            1.0,
            1.0,
            1e-10,
            1.0,
        )
        .expect("valid parameters");
        let result_se = interp_se.predict(&points.view()).expect("matching dimensions");
        assert_eq!(result_se.value.len(), values.len());
        assert_eq!(result_se.variance.len(), values.len());
        for ((&predicted, &expected), &variance) in result_se
            .value
            .iter()
            .zip(values.iter())
            .zip(result_se.variance.iter())
        {
            assert!((predicted - expected).abs() < 2.0);
            assert!(variance < 1e-6);
        }
    }

    /// Predictions between samples of y = x^2 are finite-error and have positive variance.
    #[test]
    fn test_kriging_interpolator_prediction() {
        let points = Array2::from_shape_vec((5, 1), vec![0.0, 1.0, 2.0, 3.0, 4.0])
            .expect("5x1 shape matches 5 elements");
        let values = array![0.0, 1.0, 4.0, 9.0, 16.0];
        let interp = KrigingInterpolator::new(
            &points.view(),
            &values.view(),
            CovarianceFunction::Matern52,
            1.0,
            1.0,
            1e-10,
            1.0,
        )
        .expect("valid parameters");
        let test_points =
            Array2::from_shape_vec((3, 1), vec![0.5, 1.5, 3.5]).expect("3x1 shape matches 3 elements");
        let expected = array![0.25, 2.25, 12.25];
        let result = interp
            .predict(&test_points.view())
            .expect("matching dimensions");
        assert_eq!(result.value.len(), expected.len());
        assert_eq!(result.variance.len(), expected.len());
        for ((&predicted, &target), &variance) in result
            .value
            .iter()
            .zip(expected.iter())
            .zip(result.variance.iter())
        {
            assert!((predicted - target).abs() < 20.0);
            assert!(variance > 0.0);
        }
    }

    /// Each model equals sigma_sq at zero distance and has decayed strictly inside
    /// (0, sigma_sq) at one length scale.
    #[test]
    fn test_covariance_functions() {
        let sigma_sq = 2.0;
        let length_scale = 0.5;
        let alpha = 1.0;
        for &cov_fn in &ALL_COV_FNS {
            let at_zero = KrigingInterpolator::<f64>::covariance(
                0.0,
                sigma_sq,
                length_scale,
                cov_fn,
                alpha,
            );
            assert_eq!(at_zero, sigma_sq, "{:?} must equal sigma_sq at r = 0", cov_fn);

            let at_length = KrigingInterpolator::<f64>::covariance(
                length_scale,
                sigma_sq,
                length_scale,
                cov_fn,
                alpha,
            );
            assert!(
                at_length > 0.0 && at_length < sigma_sq,
                "{:?} should decay into (0, sigma_sq) at r = length_scale, got {}",
                cov_fn,
                at_length
            );
        }
    }

    /// The free-function constructor produces a usable interpolator.
    #[test]
    fn test_make_kriging_interpolator() {
        let points = Array2::from_shape_vec((4, 2), vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0])
            .expect("4x2 shape matches 8 elements");
        let values = array![0.0, 1.0, 1.0, 2.0];
        let interp = make_kriging_interpolator(
            &points.view(),
            &values.view(),
            CovarianceFunction::SquaredExponential,
            1.0,
            0.5,
            1e-10,
            1.0,
        )
        .expect("valid parameters");
        let test_point =
            Array2::from_shape_vec((1, 2), vec![0.5, 0.5]).expect("1x2 shape matches 2 elements");
        let result = interp
            .predict(&test_point.view())
            .expect("matching dimensions");
        assert_eq!(result.value.len(), 1);
        assert!((result.value[0] - 1.0).abs() < 2.0);
    }
}