use crate::advanced::enhanced_kriging::{AnisotropicCovariance, TrendFunction};
use crate::advanced::kriging::CovarianceFunction;
use crate::error::{InterpolateError, InterpolateResult};
use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt::{Debug, Display};
use std::marker::PhantomData;
use std::ops::{Add, Div, Mul, Sub};
/// Default for the builder's `max_neighbors` setting.
const DEFAULT_MAX_NEIGHBORS: usize = 50;

/// Default for the builder's `radius_multiplier` setting.
const DEFAULT_RADIUS_MULTIPLIER: f64 = 3.0;

/// Sparse-matrix storage: a list of `(row, column)` index pairs plus a list of entry
/// values (presumably parallel, one value per pair — confirm in `covariance`).
/// Holds the tapered covariance computed for `FastKrigingMethod::Tapering`.
type SparseComponents<F> = (Vec<(usize, usize)>, Vec<F>);
/// Output of [`FastKriging::predict`]: one prediction and one variance per query point.
#[derive(Debug, Clone)]
pub struct FastPredictionResult<F>
where
    F: Float + Debug + Display + FromPrimitive,
{
    /// Predicted value at each query point.
    pub value: Array1<F>,
    /// Prediction variance at each query point, aligned index-by-index with `value`.
    pub variance: Array1<F>,
    /// Approximation method that produced this result.
    pub method: FastKrigingMethod,
    /// Optional computation time in milliseconds (`None` when it was not recorded).
    pub computation_time_ms: Option<f64>,
}
/// Returns `Φ⁻¹(p)`, the quantile function (inverse CDF) of the standard normal distribution.
///
/// Uses Peter J. Acklam's rational approximation, whose relative error is below about
/// `1.15e-9` over the open unit interval — far finer than any confidence-interval width
/// needs. `p` must lie strictly inside `(0, 1)`; callers are responsible for checking that.
fn standard_normal_quantile(p: f64) -> f64 {
    // Central-region coefficients (numerator `A`, denominator `B`).
    const A: [f64; 6] = [
        -3.969683028665376e+01,
        2.209460984245205e+02,
        -2.759285104469687e+02,
        1.383577518672690e+02,
        -3.066479806614716e+01,
        2.506628277459239e+00,
    ];
    const B: [f64; 5] = [
        -5.447609879822406e+01,
        1.615858368580409e+02,
        -1.556989798598866e+02,
        6.680131188771972e+01,
        -1.328068155288572e+01,
    ];
    // Tail-region coefficients (numerator `C`, denominator `D`).
    const C: [f64; 6] = [
        -7.784894002430293e-03,
        -3.223964580411365e-01,
        -2.400758277161838e+00,
        -2.549671010229528e+00,
        4.374664141464968e+00,
        2.938163982698783e+00,
    ];
    const D: [f64; 4] = [
        7.784695709041462e-03,
        3.224671290700398e-01,
        2.445134137142996e+00,
        3.754408661907416e+00,
    ];
    // Break-points between the lower tail, central region and upper tail.
    const P_LOW: f64 = 0.02425;
    const P_HIGH: f64 = 1.0 - P_LOW;

    // Lower-tail rational function in q = sqrt(-2 ln p); the upper tail uses symmetry.
    let tail = |q: f64| -> f64 {
        (((((C[0] * q + C[1]) * q + C[2]) * q + C[3]) * q + C[4]) * q + C[5])
            / ((((D[0] * q + D[1]) * q + D[2]) * q + D[3]) * q + 1.0)
    };

    if p < P_LOW {
        tail((-2.0 * p.ln()).sqrt())
    } else if p <= P_HIGH {
        let q = p - 0.5;
        let r = q * q;
        (((((A[0] * r + A[1]) * r + A[2]) * r + A[3]) * r + A[4]) * r + A[5]) * q
            / (((((B[0] * r + B[1]) * r + B[2]) * r + B[3]) * r + B[4]) * r + 1.0)
    } else {
        -tail((-2.0 * (1.0 - p).ln()).sqrt())
    }
}

impl<F> FastPredictionResult<F>
where
    F: Float + FromPrimitive + Debug + Display,
{
    /// Number of predicted points.
    pub fn len(&self) -> usize {
        self.value.len()
    }

    /// Returns `true` when the result contains no predictions.
    pub fn is_empty(&self) -> bool {
        self.value.is_empty()
    }

    /// Borrowed view of the predicted values.
    pub fn values(&self) -> &Array1<F> {
        &self.value
    }

    /// Borrowed view of the prediction variances.
    pub fn variances(&self) -> &Array1<F> {
        &self.variance
    }

    /// Element-wise standard deviations, `sqrt(variance)`.
    ///
    /// Negative variances (e.g. produced by round-off in an approximate solver) are
    /// clamped to zero instead of yielding NaN. NaN variances stay NaN.
    pub fn standard_deviations(&self) -> Array1<F> {
        self.variance
            .mapv(|v| if v < F::zero() { F::zero() } else { v.sqrt() })
    }

    /// Two-sided normal confidence intervals around each prediction.
    ///
    /// Row `i` of the returned `(len, 2)` array is `[value[i] - z * sd[i], value[i] + z * sd[i]]`,
    /// where `z` is the exact standard-normal critical value for `confidence_level`
    /// (e.g. `1.96` for `0.95`, `2.576` for `0.99`).
    ///
    /// # Errors
    ///
    /// * [`InterpolateError::InvalidValue`] if `confidence_level` is NaN or not strictly
    ///   between 0 and 1.
    /// * [`InterpolateError::DimensionMismatch`] if `value` and `variance` differ in length.
    pub fn confidence_intervals(&self, confidence_level: f64) -> InterpolateResult<Array2<F>> {
        if confidence_level.is_nan() || confidence_level <= 0.0 || confidence_level >= 1.0 {
            return Err(InterpolateError::InvalidValue(
                "Confidence level must be between 0 and 1".to_string(),
            ));
        }
        if self.variance.len() != self.value.len() {
            return Err(InterpolateError::DimensionMismatch(format!(
                "Prediction has {} values but {} variances",
                self.value.len(),
                self.variance.len()
            )));
        }

        // z is the upper (1 - level) / 2 quantile. Evaluating the small tail probability
        // directly avoids the cancellation in `1 - (1 + level) / 2` as level -> 1.
        let tail_probability = 0.5 * (1.0 - confidence_level);
        let z_score = F::from_f64(-standard_normal_quantile(tail_probability)).ok_or_else(|| {
            InterpolateError::InvalidValue(
                "Critical value is not representable in the target float type".to_string(),
            )
        })?;

        let std_devs = self.standard_deviations();
        let mut intervals = Array2::zeros((self.len(), 2));
        for (i, (&mean, &std_dev)) in self.value.iter().zip(std_devs.iter()).enumerate() {
            let margin = z_score * std_dev;
            intervals[[i, 0]] = mean - margin;
            intervals[[i, 1]] = mean + margin;
        }
        Ok(intervals)
    }
}
pub mod acceleration;
pub mod covariance;
pub mod ordinary;
pub mod universal;
pub mod variogram;
pub use acceleration::*;
pub use covariance::*;
pub use ordinary::*;
pub use universal::*;
pub use variogram::*;
/// Strategy used by [`FastKriging`] to approximate the full kriging system.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum FastKrigingMethod {
    /// Local kriging from a neighbourhood of the training points
    /// (presumably bounded by the model's `max_neighbors` — confirm in the local predictor).
    Local,
    /// Low-rank covariance approximation; the payload is the rank handed to
    /// `covariance::compute_low_rank_approximation`.
    FixedRank(usize),
    /// Covariance tapering; the payload is the taper range handed to
    /// `covariance::compute_tapered_covariance`.
    Tapering(f64),
    /// Hierarchical off-diagonal low-rank approximation.
    /// NOTE(review): the payload's meaning (leaf size or off-diagonal rank?) is not
    /// visible in this module — confirm against the HODLR predictor.
    HODLR(usize),
}
/// Kriging interpolator that uses one of several covariance approximations
/// (see [`FastKrigingMethod`]). Construct it through [`FastKrigingBuilder`].
#[derive(Debug, Clone)]
pub struct FastKriging<F>
where
    F: Float
        + FromPrimitive
        + Debug
        + Display
        + Add<Output = F>
        + Sub<Output = F>
        + Mul<Output = F>
        + Div<Output = F>
        + std::ops::AddAssign
        + std::ops::SubAssign
        + std::ops::MulAssign
        + std::ops::DivAssign
        + std::ops::RemAssign,
{
    /// Training locations: one row per sample, one column per dimension.
    points: Array2<F>,
    /// Observed value for each training row.
    values: Array1<F>,
    /// Covariance model built from the builder's function, length scales, sill and nugget.
    anisotropic_cov: AnisotropicCovariance<F>,
    /// Trend model; any non-constant trend gets cached basis functions and coefficients.
    trend_fn: TrendFunction,
    /// Selected approximation strategy.
    approx_method: FastKrigingMethod,
    /// Neighbour limit (presumably for local prediction — confirm in the predictor).
    max_neighbors: usize,
    /// Search-radius multiplier (presumably for neighbour search — confirm).
    radius_multiplier: F,
    /// `(u, s, v)` low-rank factors, populated only for `FixedRank`.
    low_rank_components: Option<(Array2<F>, Array1<F>, Array2<F>)>,
    /// Sparse tapered covariance, populated only for `Tapering`.
    sparse_components: Option<SparseComponents<F>>,
    /// Kriging weights; starts empty at construction.
    weights: Array1<F>,
    /// Trend basis evaluated at the training points (non-constant trends only).
    basis_functions: Option<Array2<F>>,
    /// Fitted trend coefficients (non-constant trends only).
    trend_coeffs: Option<Array1<F>>,
    /// Parameter-optimisation flag; `from_builder` always sets it to `false`.
    optimize_parameters: bool,
    /// Exact-variance flag; `from_builder` always sets it to `false`.
    compute_exact_variance: bool,
    /// Marker tying the struct to `F`.
    _phantom: PhantomData<F>,
}
/// Step-by-step configuration for a [`FastKriging`] model; finish with `build()`.
#[derive(Debug, Clone)]
pub struct FastKrigingBuilder<F>
where
    F: Float
        + FromPrimitive
        + Debug
        + Display
        + Add<Output = F>
        + Sub<Output = F>
        + Mul<Output = F>
        + Div<Output = F>
        + std::ops::AddAssign
        + std::ops::SubAssign
        + std::ops::MulAssign
        + std::ops::DivAssign
        + std::ops::RemAssign,
{
    /// Training locations (required).
    points: Option<Array2<F>>,
    /// Training values (required).
    values: Option<Array1<F>>,
    /// Covariance function; `new()` uses `Matern52`.
    cov_fn: CovarianceFunction,
    /// Per-dimension length scales; when unset, every dimension uses 1.
    length_scales: Option<Array1<F>>,
    /// Sill / variance parameter; `new()` uses 1.0.
    sigma_sq: F,
    /// Nugget term; `new()` uses 1e-6.
    nugget: F,
    /// Trend model; `new()` uses `TrendFunction::Constant`.
    trend_fn: TrendFunction,
    /// Approximation strategy; `new()` uses `FastKrigingMethod::Local`.
    approx_method: FastKrigingMethod,
    /// Neighbour limit; `new()` uses `DEFAULT_MAX_NEIGHBORS`.
    max_neighbors: usize,
    /// Radius multiplier; `new()` uses `DEFAULT_RADIUS_MULTIPLIER`.
    radius_multiplier: F,
    /// Marker tying the builder to `F`.
    _phantom: PhantomData<F>,
}
impl<F> FastKriging<F>
where
F: Float
+ FromPrimitive
+ Debug
+ Display
+ Add<Output = F>
+ Sub<Output = F>
+ Mul<Output = F>
+ Div<Output = F>
+ std::ops::AddAssign
+ std::ops::SubAssign
+ std::ops::MulAssign
+ std::ops::DivAssign
+ std::ops::RemAssign
+ 'static,
{
pub(crate) fn from_builder(builder: FastKrigingBuilder<F>) -> InterpolateResult<Self> {
let points = builder.points.ok_or(InterpolateError::MissingPoints)?;
let values = builder.values.ok_or(InterpolateError::MissingValues)?;
if points.nrows() != values.len() {
return Err(InterpolateError::DimensionMismatch(
"Number of points must match number of values".to_string(),
));
}
let anisotropic_cov = AnisotropicCovariance::new(
builder.cov_fn,
_builder
.length_scales
.unwrap_or_else(|| Array1::from_elem(points.ncols(), F::one())),
builder.sigma_sq,
builder.nugget,
);
let mut kriging = Self {
points,
values,
anisotropic_cov,
trend_fn: builder.trend_fn,
approx_method: builder.approx_method,
max_neighbors: builder.max_neighbors,
radius_multiplier: builder.radius_multiplier,
low_rank_components: None,
sparse_components: None,
weights: Array1::zeros(0),
basis_functions: None,
trend_coeffs: None,
optimize_parameters: false,
compute_exact_variance: false, _phantom: PhantomData,
};
kriging.initialize_approximation()?;
Ok(kriging)
}
fn initialize_approximation(&mut self) -> InterpolateResult<()> {
match self.approx_method {
FastKrigingMethod::FixedRank(rank) => {
let (u, s, v) = covariance::compute_low_rank_approximation(
&self.points,
&self.anisotropic_cov,
rank,
)?;
self.low_rank_components = Some((u, s, v));
}
FastKrigingMethod::Tapering(range) => {
let sparse_components = covariance::compute_tapered_covariance(
&self.points,
&self.anisotropic_cov,
F::from_f64(range).expect("Operation failed"),
)?;
self.sparse_components = Some(sparse_components);
}
_ => {
}
}
if self.trend_fn != crate::advanced::enhanced_kriging::TrendFunction::Constant {
let basis = universal::create_basis_functions(&self.points.view(), self.trend_fn)?;
let coeffs = universal::compute_trend_coefficients(
&self.points,
&self.values,
&basis,
self.trend_fn,
)?;
self.basis_functions = Some(basis);
self.trend_coeffs = Some(coeffs);
}
Ok(())
}
pub fn predict(
&self,
query_points: &ArrayView2<F>,
) -> InterpolateResult<FastPredictionResult<F>> {
if query_points.ncols() != self._points.ncols() {
return Err(InterpolateError::DimensionMismatch(
"Query _points must have same dimensionality as training _points".to_string(),
));
}
match self.approx_method {
FastKrigingMethod::Local => self.predict_local(query_points),
FastKrigingMethod::FixedRank(_) => self.predict_fixed_rank(query_points),
FastKrigingMethod::Tapering(_) => self.predict_tapered(query_points),
FastKrigingMethod::HODLR(_) => self.predict_hodlr(query_points),
}
}
pub fn n_points(&self) -> usize {
self.points.nrows()
}
pub fn n_dims(&self) -> usize {
self.points.ncols()
}
pub fn approximation_method(&self) -> FastKrigingMethod {
self.approx_method
}
}
impl<F> Default for FastKrigingBuilder<F>
where
F: Float
+ FromPrimitive
+ Debug
+ Display
+ Add<Output = F>
+ Sub<Output = F>
+ Mul<Output = F>
+ Div<Output = F>
+ std::ops::AddAssign
+ std::ops::SubAssign
+ std::ops::MulAssign
+ std::ops::DivAssign
+ std::ops::RemAssign
+ 'static,
{
fn default() -> Self {
Self::new()
}
}
impl<F> FastKrigingBuilder<F>
where
    F: Float
        + FromPrimitive
        + Debug
        + Display
        + Add<Output = F>
        + Sub<Output = F>
        + Mul<Output = F>
        + Div<Output = F>
        + std::ops::AddAssign
        + std::ops::SubAssign
        + std::ops::MulAssign
        + std::ops::DivAssign
        + std::ops::RemAssign
        + 'static,
{
    /// Creates a builder with the default configuration.
    ///
    /// Defaults: Matérn 5/2 covariance, unit length scales, `sigma_sq = 1`,
    /// `nugget = 1e-6`, constant trend, local approximation,
    /// `DEFAULT_MAX_NEIGHBORS` neighbours and `DEFAULT_RADIUS_MULTIPLIER`.
    pub fn new() -> Self {
        Self {
            points: None,
            values: None,
            cov_fn: CovarianceFunction::Matern52,
            length_scales: None,
            sigma_sq: F::one(),
            nugget: F::from_f64(1e-6).expect("1e-6 is representable in every Float type"),
            trend_fn: TrendFunction::Constant,
            approx_method: FastKrigingMethod::Local,
            max_neighbors: DEFAULT_MAX_NEIGHBORS,
            radius_multiplier: F::from_f64(DEFAULT_RADIUS_MULTIPLIER)
                .expect("default radius multiplier is representable in every Float type"),
            _phantom: PhantomData,
        }
    }

    /// Sets the training locations (one row per point).
    pub fn points(mut self, points: Array2<F>) -> Self {
        self.points = Some(points);
        self
    }

    /// Sets the observed values (one per training point).
    pub fn values(mut self, values: Array1<F>) -> Self {
        self.values = Some(values);
        self
    }

    /// Sets the covariance function.
    pub fn covariance_function(mut self, cov_fn: CovarianceFunction) -> Self {
        self.cov_fn = cov_fn;
        self
    }

    /// Sets per-dimension length scales (one entry per input dimension).
    pub fn length_scales(mut self, length_scales: Array1<F>) -> Self {
        self.length_scales = Some(length_scales);
        self
    }

    /// Sets the sill / variance parameter.
    pub fn sigma_sq(mut self, sigma_sq: F) -> Self {
        self.sigma_sq = sigma_sq;
        self
    }

    /// Sets the nugget term.
    pub fn nugget(mut self, nugget: F) -> Self {
        self.nugget = nugget;
        self
    }

    /// Sets the trend model.
    pub fn trend_function(mut self, trend_fn: TrendFunction) -> Self {
        self.trend_fn = trend_fn;
        self
    }

    /// Selects the approximation strategy.
    pub fn approximation_method(mut self, method: FastKrigingMethod) -> Self {
        self.approx_method = method;
        self
    }

    /// Sets the maximum neighbour count.
    pub fn max_neighbors(mut self, max_neighbors: usize) -> Self {
        self.max_neighbors = max_neighbors;
        self
    }

    /// Sets the radius multiplier.
    pub fn radius_multiplier(mut self, radius_multiplier: F) -> Self {
        self.radius_multiplier = radius_multiplier;
        self
    }

    /// Validates the configuration and builds the model.
    ///
    /// # Errors
    ///
    /// Propagates every validation or initialisation error from `FastKriging::from_builder`.
    pub fn build(self) -> InterpolateResult<FastKriging<F>> {
        FastKriging::from_builder(self)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::advanced::kriging::CovarianceFunction;
    use scirs2_core::ndarray::{Array1, Array2};

    /// Builds a smooth synthetic data set with `n_points` samples in `n_dims` dimensions.
    ///
    /// Coordinate `d` of sample `i` is `i / n_points + 0.1 * d`, and the value is
    /// `x^2 + y^2 + 0.1 x y` (with `y = 0` for one-dimensional data).
    fn create_test_data(n_points: usize, n_dims: usize) -> (Array2<f64>, Array1<f64>) {
        let mut points = Array2::zeros((n_points, n_dims));
        let mut values = Array1::zeros(n_points);
        for i in 0..n_points {
            for d in 0..n_dims {
                points[[i, d]] = (i as f64) / (n_points as f64) + (d as f64) * 0.1;
            }
            let x = points[[i, 0]];
            let y = if n_dims > 1 { points[[i, 1]] } else { 0.0 };
            values[i] = x * x + y * y + 0.1 * x * y;
        }
        (points, values)
    }

    #[test]
    fn test_fast_prediction_result_methods() {
        let value = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let variance = Array1::from_vec(vec![0.1, 0.2, 0.3]);
        let result = FastPredictionResult {
            value,
            variance,
            method: FastKrigingMethod::Local,
            computation_time_ms: Some(100.0),
        };
        assert_eq!(result.len(), 3);
        assert!(!result.is_empty());
        assert_eq!(result.values().len(), 3);
        assert_eq!(result.variances().len(), 3);
        let std_devs = result.standard_deviations();
        assert!((std_devs[0] - 0.1_f64.sqrt()).abs() < 1e-10);
    }

    #[test]
    fn test_confidence_intervals() {
        let value = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let variance = Array1::from_vec(vec![0.01, 0.04, 0.09]);
        let result = FastPredictionResult {
            value,
            variance,
            method: FastKrigingMethod::Local,
            computation_time_ms: None,
        };
        let intervals = result.confidence_intervals(0.95).expect("Operation failed");
        assert_eq!(intervals.nrows(), 3);
        assert_eq!(intervals.ncols(), 2);
        for i in 0..3 {
            assert!(intervals[[i, 0]] < intervals[[i, 1]]);
            // Intervals are symmetric about the prediction.
            let midpoint = 0.5 * (intervals[[i, 0]] + intervals[[i, 1]]);
            assert!((midpoint - result.value[i]).abs() < 1e-12);
        }
        assert!(result.confidence_intervals(1.1).is_err());
        assert!(result.confidence_intervals(-0.1).is_err());
    }

    #[test]
    fn test_fast_kriging_method_variants() {
        assert_eq!(FastKrigingMethod::Local, FastKrigingMethod::Local);
        assert_ne!(FastKrigingMethod::Local, FastKrigingMethod::FixedRank(10));
        let method1 = FastKrigingMethod::FixedRank(10);
        let method2 = FastKrigingMethod::FixedRank(10);
        let method3 = FastKrigingMethod::FixedRank(20);
        assert_eq!(method1, method2);
        assert_ne!(method1, method3);
    }

    #[test]
    fn test_fast_kriging_builder_default() {
        let builder = FastKrigingBuilder::<f64>::new();
        assert!(builder.points.is_none());
        assert!(builder.values.is_none());
        assert_eq!(builder.cov_fn, CovarianceFunction::Matern52);
        assert_eq!(builder.approx_method, FastKrigingMethod::Local);
        assert_eq!(builder.max_neighbors, DEFAULT_MAX_NEIGHBORS);
    }

    #[test]
    fn test_fast_kriging_builder_methods() {
        let (points, values) = create_test_data(10, 2);
        let length_scales = Array1::from_vec(vec![1.0, 1.5]);
        let builder = FastKrigingBuilder::<f64>::new()
            .points(points.clone())
            .values(values.clone())
            .covariance_function(CovarianceFunction::Exponential)
            .length_scales(length_scales.clone())
            .sigma_sq(2.0)
            .nugget(0.01)
            .approximation_method(FastKrigingMethod::FixedRank(5))
            .max_neighbors(15)
            .radius_multiplier(2.0);
        assert!(builder.points.is_some());
        assert!(builder.values.is_some());
        assert_eq!(builder.cov_fn, CovarianceFunction::Exponential);
        assert_eq!(builder.sigma_sq, 2.0);
        assert_eq!(builder.nugget, 0.01);
        assert_eq!(builder.max_neighbors, 15);
        assert_eq!(builder.radius_multiplier, 2.0);
    }

    #[test]
    fn test_fast_kriging_build_missing_data() {
        let builder = FastKrigingBuilder::<f64>::new();
        let result = builder.build();
        assert!(result.is_err());
        let (points, _values) = create_test_data(5, 2);
        let builder = FastKrigingBuilder::<f64>::new().points(points);
        let result = builder.build();
        assert!(result.is_err());
    }

    #[test]
    fn test_fast_kriging_dimension_mismatch() {
        let points = Array2::zeros((5, 2));
        let values = Array1::zeros(3);
        let result = FastKrigingBuilder::<f64>::new()
            .points(points)
            .values(values)
            .build();
        assert!(result.is_err());
    }

    #[cfg(feature = "linalg")]
    #[test]
    fn test_local_kriging() {
        let (points, values) = create_test_data(20, 2);
        let kriging = FastKrigingBuilder::<f64>::new()
            .points(points.clone())
            .values(values.clone())
            .covariance_function(CovarianceFunction::Matern52)
            .approximation_method(FastKrigingMethod::Local)
            .max_neighbors(10)
            .build()
            .expect("Operation failed");
        assert_eq!(kriging.n_points(), 20);
        assert_eq!(kriging.n_dims(), 2);
        assert_eq!(kriging.approximation_method(), FastKrigingMethod::Local);
        let query_points = Array2::from_shape_vec((2, 2), vec![0.25, 0.25, 0.75, 0.75])
            .expect("Operation failed");
        let result = kriging.predict(&query_points.view()).expect("Operation failed");
        assert_eq!(result.len(), 2);
        assert_eq!(result.method, FastKrigingMethod::Local);
        for &val in result.values().iter() {
            assert!(val.is_finite());
        }
        for &var in result.variances().iter() {
            assert!(var >= 0.0);
        }
    }

    #[cfg(feature = "linalg")]
    #[test]
    fn test_fixed_rank_kriging() {
        let (points, values) = create_test_data(30, 2);
        let kriging = FastKrigingBuilder::<f64>::new()
            .points(points)
            .values(values)
            .covariance_function(CovarianceFunction::Exponential)
            .approximation_method(FastKrigingMethod::FixedRank(5))
            .build()
            .expect("Operation failed");
        let query_points = Array2::from_shape_vec((3, 2), vec![0.1, 0.1, 0.5, 0.5, 0.9, 0.9])
            .expect("Operation failed");
        let result = kriging.predict(&query_points.view()).expect("Operation failed");
        assert_eq!(result.len(), 3);
        assert_eq!(result.method, FastKrigingMethod::FixedRank(5));
        for &val in result.values().iter() {
            assert!(val.is_finite());
        }
    }

    #[cfg(feature = "linalg")]
    #[test]
    fn test_tapered_kriging() {
        let (points, values) = create_test_data(25, 2);
        let kriging = FastKrigingBuilder::<f64>::new()
            .points(points)
            .values(values)
            .covariance_function(CovarianceFunction::SquaredExponential)
            .approximation_method(FastKrigingMethod::Tapering(1.5))
            .build()
            .expect("Operation failed");
        let query_points = Array2::from_shape_vec((2, 2), vec![0.3, 0.3, 0.7, 0.7])
            .expect("Operation failed");
        let result = kriging.predict(&query_points.view()).expect("Operation failed");
        assert_eq!(result.len(), 2);
        assert_eq!(result.method, FastKrigingMethod::Tapering(1.5));
        for &val in result.values().iter() {
            assert!(val.is_finite());
        }
        for &var in result.variances().iter() {
            assert!(var >= 0.0);
        }
    }

    #[cfg(feature = "linalg")]
    #[test]
    fn test_hodlr_kriging() {
        let (points, values) = create_test_data(40, 2);
        let kriging = FastKrigingBuilder::<f64>::new()
            .points(points)
            .values(values)
            .covariance_function(CovarianceFunction::Matern32)
            .approximation_method(FastKrigingMethod::HODLR(8))
            .build()
            .expect("Operation failed");
        let query_points = Array2::from_shape_vec((2, 2), vec![0.4, 0.4, 0.6, 0.6])
            .expect("Operation failed");
        let result = kriging.predict(&query_points.view()).expect("Operation failed");
        assert_eq!(result.len(), 2);
        assert_eq!(result.method, FastKrigingMethod::HODLR(8));
        for &val in result.values().iter() {
            assert!(val.is_finite());
        }
    }

    #[test]
    fn test_predict_dimension_mismatch() {
        let (points, values) = create_test_data(10, 2);
        let kriging = FastKrigingBuilder::<f64>::new()
            .points(points)
            .values(values)
            .approximation_method(FastKrigingMethod::Local)
            .build()
            .expect("Operation failed");
        let wrong_query = Array2::zeros((2, 3));
        let result = kriging.predict(&wrong_query.view());
        assert!(result.is_err());
    }

    #[test]
    fn test_empty_query_prediction() {
        let (points, values) = create_test_data(10, 2);
        let kriging = FastKrigingBuilder::<f64>::new()
            .points(points)
            .values(values)
            .approximation_method(FastKrigingMethod::Local)
            .build()
            .expect("Operation failed");
        let empty_query = Array2::zeros((0, 2));
        let result = kriging.predict(&empty_query.view()).expect("Operation failed");
        assert_eq!(result.len(), 0);
        assert!(result.is_empty());
    }

    #[test]
    fn test_single_point_dataset() {
        let points = Array2::from_shape_vec((1, 2), vec![0.5, 0.5]).expect("Operation failed");
        let values = Array1::from_vec(vec![1.0]);
        let kriging = FastKrigingBuilder::<f64>::new()
            .points(points)
            .values(values)
            .approximation_method(FastKrigingMethod::Local)
            .max_neighbors(1)
            .build()
            .expect("Operation failed");
        let query_points = Array2::from_shape_vec((1, 2), vec![0.6, 0.6]).expect("Operation failed");
        let result = kriging.predict(&query_points.view());
        assert!(result.is_ok());
    }
}