use crate::advanced::enhanced_kriging::TrendFunction;
use crate::advanced::kriging::CovarianceFunction;
use crate::error::{InterpolateError, InterpolateResult};
use crate::spatial::kdtree::KdTree;
use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt::{Debug, Display};
use std::marker::PhantomData;
use std::ops::{Add, Div, Mul, Sub};
/// Sparse matrix in coordinate form: `(row, col)` index pairs and the matching entry values.
type SparseComponents<F> = (Vec<(usize, usize)>, Vec<F>);
/// Default number of nearest neighbours used by the `Local` method when the builder does not
/// override `max_neighbors`.
const DEFAULT_MAX_NEIGHBORS: usize = 50;
/// Default value of the builder's `radius_multiplier` (stored on the model but not read by any
/// predictor in this module).
const DEFAULT_RADIUS_MULTIPLIER: f64 = 3.0;
/// Evaluates an isotropic covariance kernel at `r`, a distance already divided by the
/// length scale.
///
/// `sigma_sq` is the marginal variance: every kernel evaluates to `sigma_sq` at `r = 0`.
/// The rational-quadratic kernel uses a fixed shape parameter `alpha = 1`.
fn eval_covariance<F: Float + FromPrimitive>(r: F, sigma_sq: F, cov_fn: CovarianceFunction) -> F {
    let lit = |x: f64| F::from_f64(x).expect("const");
    match cov_fn {
        CovarianceFunction::SquaredExponential => sigma_sq * (-r * r).exp(),
        CovarianceFunction::Exponential => sigma_sq * (-r).exp(),
        CovarianceFunction::Matern32 => {
            let scaled = lit(3.0_f64.sqrt()) * r;
            sigma_sq * (F::one() + scaled) * (-scaled).exp()
        }
        CovarianceFunction::Matern52 => {
            let scaled = lit(5.0_f64.sqrt()) * r;
            let poly = F::one() + scaled + lit(5.0 / 3.0) * r * r;
            sigma_sq * poly * (-scaled).exp()
        }
        CovarianceFunction::RationalQuadratic => {
            let alpha = F::one();
            let base = F::one() + r * r / (lit(2.0) * alpha);
            sigma_sq * base.powf(-alpha)
        }
    }
}
/// Euclidean (L2) distance between two coordinate vectors.
///
/// Components are paired with `zip`, so any extra trailing components of the longer vector
/// are ignored.
fn euclidean_distance<F: Float>(a: &ArrayView1<F>, b: &ArrayView1<F>) -> F {
    a.iter()
        .zip(b.iter())
        .fold(F::zero(), |acc, (&x, &y)| {
            let diff = x - y;
            acc + diff * diff
        })
        .sqrt()
}
/// Wendland C2 compactly supported taper `(1 - u)^4 (1 + 4u)` with `u = r / range`.
///
/// Returns exactly zero once `r >= range`, which is what makes tapered covariance matrices
/// sparse. Both `r` and `range` are in the same (unscaled) distance units.
fn wendland_c2<F: Float + FromPrimitive>(r: F, range: F) -> F {
    if r >= range {
        F::zero()
    } else {
        let u = r / range;
        let complement = F::one() - u;
        complement * complement * complement * complement
            * (F::one() + F::from_f64(4.0).expect("const") * u)
    }
}
/// Computes the lower-triangular Cholesky factor `L` with `L Lᵀ = a`.
///
/// Only the lower triangle of `a` is read.
///
/// # Errors
///
/// * `DimensionMismatch` if `a` is not square.
/// * `ComputationError` if a pivot is not strictly positive or is NaN, i.e. `a` is not
///   (numerically) positive definite.
fn cholesky_lower<F: Float + FromPrimitive>(a: &Array2<F>) -> InterpolateResult<Array2<F>> {
    let n = a.nrows();
    let cols = a.ncols();
    if cols != n {
        // Indexing a non-square matrix below would panic; report it as an error instead.
        return Err(InterpolateError::DimensionMismatch(format!(
            "Cholesky: expected a square matrix, got {n}x{cols}"
        )));
    }
    let mut l = Array2::<F>::zeros((n, n));
    for i in 0..n {
        for j in 0..=i {
            let mut s = a[[i, j]];
            for k in 0..j {
                s = s - l[[i, k]] * l[[j, k]];
            }
            if i == j {
                // `s <= 0` alone is false for NaN, which would silently yield a NaN factor
                // (and make callers skip their regularisation fallback).
                if s <= F::zero() || s.is_nan() {
                    return Err(InterpolateError::ComputationError(
                        "Cholesky: matrix not positive-definite".to_string(),
                    ));
                }
                l[[i, j]] = s.sqrt();
            } else {
                l[[i, j]] = s / l[[j, j]];
            }
        }
    }
    Ok(l)
}
/// Solves `L x = b` for lower-triangular `L` by forward substitution.
///
/// # Errors
///
/// Returns `ComputationError` if a diagonal entry of `L` has magnitude below `1e-300`.
fn forward_sub<F: Float + FromPrimitive>(
    l: &Array2<F>,
    b: &Array1<F>,
) -> InterpolateResult<Array1<F>> {
    let tiny = F::from_f64(1e-300).expect("const");
    let mut x = Array1::<F>::zeros(b.len());
    for (row, &rhs) in b.iter().enumerate() {
        // Subtract the contributions of the already-solved unknowns, in column order.
        let residual = (0..row).fold(rhs, |acc, col| acc - l[[row, col]] * x[col]);
        let pivot = l[[row, row]];
        if pivot.abs() < tiny {
            return Err(InterpolateError::ComputationError(
                "Forward substitution: near-zero diagonal element".to_string(),
            ));
        }
        x[row] = residual / pivot;
    }
    Ok(x)
}
/// Solves `Lᵀ x = b` by back substitution, reading the lower-triangular `L` transposed
/// (so `L` never has to be materialised as an upper-triangular matrix).
///
/// # Errors
///
/// Returns `ComputationError` if a diagonal entry of `L` has magnitude below `1e-300`.
fn back_sub_transpose<F: Float + FromPrimitive>(
    l: &Array2<F>,
    b: &Array1<F>,
) -> InterpolateResult<Array1<F>> {
    let n = b.len();
    let tiny = F::from_f64(1e-300).expect("const");
    let mut x = Array1::<F>::zeros(n);
    for row in (0..n).rev() {
        // (Lᵀ)[row, col] == L[col, row]; accumulate over the already-solved tail.
        let residual = ((row + 1)..n).fold(b[row], |acc, col| acc - l[[col, row]] * x[col]);
        let pivot = l[[row, row]];
        if pivot.abs() < tiny {
            return Err(InterpolateError::ComputationError(
                "Back substitution: near-zero diagonal element".to_string(),
            ));
        }
        x[row] = residual / pivot;
    }
    Ok(x)
}
/// Solves `a x = b` for a symmetric positive-definite `a` via its Cholesky factor.
///
/// If the plain factorisation or either triangular solve fails, the solve is retried once
/// with diagonal jitter `eps = 1e-6 · mean(diag(a))`, floored at `1e-12`.
///
/// # Errors
///
/// Returns the factorisation/substitution error of the regularised retry if that fails too.
fn cholesky_solve<F: Float + FromPrimitive>(
    a: &Array2<F>,
    b: &Array1<F>,
) -> InterpolateResult<Array1<F>> {
    let solve = |matrix: &Array2<F>| -> InterpolateResult<Array1<F>> {
        let l = cholesky_lower(matrix)?;
        let y = forward_sub(&l, b)?;
        back_sub_transpose(&l, &y)
    };
    if let Ok(x) = solve(a) {
        return Ok(x);
    }

    // Regularise relative to the average diagonal so the jitter scales with the problem.
    let n = a.nrows();
    let trace = (0..n).fold(F::zero(), |acc, i| acc + a[[i, i]]);
    let floor = F::from_f64(1e-12).expect("const");
    let eps = trace / F::from_usize(n).expect("const") * F::from_f64(1e-6).expect("const");
    let eps = if eps < floor { floor } else { eps };

    // NOTE: the original code spelled this argument as a mis-encoded `&reg` ("®"), which did
    // not compile; the descriptive name avoids that entity collision.
    let mut regularized = a.clone();
    for i in 0..n {
        regularized[[i, i]] = regularized[[i, i]] + eps;
    }
    solve(&regularized)
}
/// Precomputed state for the `FixedRank` (inducing-point) approximation.
#[derive(Debug, Clone)]
struct NystromState<F> {
/// Training points selected as inducing points, one per row.
inducing_points: Array2<F>,
/// Lower Cholesky factor of the inducing-point covariance matrix (nugget on its diagonal);
/// used for the predictive variance.
l_mm: Array2<F>,
/// Weight vector that is dotted with the query/inducing-point cross-covariance to form the
/// predictive mean.
kmi_kmy: Array1<F>,
/// Number of inducing points actually used (the requested rank capped at the number of
/// training points).
rank: usize,
}
/// Precomputed state for the `Tapering` approximation.
#[derive(Debug, Clone)]
struct TaperState<F> {
/// Support radius of the Wendland C2 taper, in unscaled distance units.
taper_range: F,
/// Coordinate-form entries of the tapered training covariance (nugget on the diagonal).
/// NOTE(review): built at construction time but not used by the tapered predictor yet.
sparse: SparseComponents<F>,
}
/// Method-specific data precomputed when the model is built.
#[derive(Debug, Clone)]
enum ApproxState<F: Float + Debug> {
/// Nothing precomputed (the `Local` and `HODLR` methods work directly on the training data).
None,
/// State for the `FixedRank` method.
Nystrom(NystromState<F>),
/// State for the `Tapering` method.
Taper(TaperState<F>),
}
/// Approximation strategy used by [`FastKriging`] to make prediction tractable.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum FastKrigingMethod {
/// Kriging over (at most) the `max_neighbors` nearest training points of each query.
Local,
/// Inducing-point approximation with the given number of inducing points (capped at the
/// number of training points).
FixedRank(usize),
/// Covariance tapering with a Wendland C2 taper; the value is the taper support radius in
/// unscaled distance units.
Tapering(f64),
/// Block-wise prediction with the given block (leaf) size; sizes below 2 are raised to 2.
/// NOTE(review): despite the name, no hierarchical off-diagonal low-rank factorisation is
/// built — predictions blend independent kriging solves over contiguous blocks of training
/// points, weighted by distance to each block's centroid.
HODLR(usize),
}
/// Output of [`FastKriging::predict`]: one mean and one variance per query row.
#[derive(Debug, Clone)]
pub struct FastPredictionResult<F: Float> {
/// Predicted mean for each query point.
pub value: Array1<F>,
/// Kriging variance estimate for each query point.
pub variance: Array1<F>,
/// Approximation method that produced this result.
pub method: FastKrigingMethod,
/// Wall-clock time of the prediction in milliseconds; the predictors in this module always
/// leave it as `None`.
pub computation_time_ms: Option<f64>,
}
/// Kriging interpolator that trades exactness for speed through one of the
/// [`FastKrigingMethod`] approximations.
///
/// Construct it with [`FastKrigingBuilder`] (or one of the `make_*_kriging` helpers) and query
/// it with [`FastKriging::predict`].
#[derive(Debug, Clone)]
pub struct FastKriging<F>
where
F: Float
+ FromPrimitive
+ ordered_float::FloatCore
+ Debug
+ Display
+ Div<Output = F>
+ Mul<Output = F>
+ Sub<Output = F>
+ Add<Output = F>
+ std::ops::AddAssign
+ std::ops::SubAssign
+ std::ops::MulAssign
+ std::ops::DivAssign
+ std::ops::RemAssign,
{
/// Training locations, one point per row.
points: Array2<F>,
/// Observed value at each training point.
values: Array1<F>,
/// Covariance kernel family.
cov_fn: CovarianceFunction,
/// Isotropic length scale; every distance is divided by it before evaluating the kernel.
length_scale: F,
/// Marginal variance (kernel value at zero distance).
sigma_sq: F,
/// Value added to the diagonal of training covariance matrices.
nugget: F,
// Stored for future use; no predictor reads it yet.
#[allow(dead_code)]
trend_fn: TrendFunction,
/// Approximation used by `predict`.
approx_method: FastKrigingMethod,
/// Neighbourhood size for the `Local` method.
max_neighbors: usize,
// Stored for future use; no predictor reads it yet.
#[allow(dead_code)]
radius_multiplier: F,
/// KD-tree over `points`; `None` if construction failed, in which case neighbour search
/// falls back to a brute-force scan.
kdtree: Option<KdTree<F>>,
/// Method-specific precomputed data.
state: ApproxState<F>,
_phantom: PhantomData<F>,
}
/// Builder for [`FastKriging`]; `points` and `values` are required, everything else has a
/// default (see [`FastKrigingBuilder::new`]).
#[derive(Debug, Clone)]
pub struct FastKrigingBuilder<F>
where
F: Float
+ FromPrimitive
+ ordered_float::FloatCore
+ Debug
+ Display
+ Div<Output = F>
+ Mul<Output = F>
+ Sub<Output = F>
+ Add<Output = F>
+ std::ops::AddAssign
+ std::ops::SubAssign
+ std::ops::MulAssign
+ std::ops::DivAssign
+ std::ops::RemAssign,
{
/// Training locations, one point per row (required).
points: Option<Array2<F>>,
/// Observed values, one per training point (required).
values: Option<Array1<F>>,
/// Covariance kernel family.
cov_fn: CovarianceFunction,
// Per-dimension length scales as supplied by the caller; only the first entry is copied
// into `length_scale`, the rest are currently unused.
#[allow(dead_code)]
length_scales: Option<Array1<F>>,
/// Isotropic length scale.
length_scale: F,
/// Marginal variance.
sigma_sq: F,
/// Diagonal nugget.
nugget: F,
/// Trend function (passed through to the model, not used in prediction).
trend_fn: TrendFunction,
/// Approximation strategy.
approx_method: FastKrigingMethod,
/// Neighbourhood size for the `Local` method.
max_neighbors: usize,
/// Radius multiplier (passed through to the model, not used in prediction).
radius_multiplier: F,
_phantom: PhantomData<F>,
}
impl<F> Default for FastKrigingBuilder<F>
where
F: Float
+ FromPrimitive
+ ordered_float::FloatCore
+ Debug
+ Display
+ Add<Output = F>
+ Sub<Output = F>
+ Mul<Output = F>
+ Div<Output = F>
+ std::ops::AddAssign
+ std::ops::SubAssign
+ std::ops::MulAssign
+ std::ops::DivAssign
+ std::ops::RemAssign
+ 'static,
{
fn default() -> Self {
Self::new()
}
}
impl<F> FastKrigingBuilder<F>
where
    F: Float
        + FromPrimitive
        + ordered_float::FloatCore
        + Debug
        + Display
        + Add<Output = F>
        + Sub<Output = F>
        + Mul<Output = F>
        + Div<Output = F>
        + std::ops::AddAssign
        + std::ops::SubAssign
        + std::ops::MulAssign
        + std::ops::DivAssign
        + std::ops::RemAssign
        + 'static,
{
    /// Creates a builder with the default configuration.
    ///
    /// Defaults:
    /// * Matérn 5/2 covariance with length scale `1`, `sigma_sq = 1` and nugget `1e-6`.
    /// * Constant trend.
    /// * [`FastKrigingMethod::Local`] with 50 neighbours.
    /// * Radius multiplier `3`.
    pub fn new() -> Self {
        let one = F::from_f64(1.0).expect("const");
        Self {
            points: None,
            values: None,
            cov_fn: CovarianceFunction::Matern52,
            length_scales: None,
            length_scale: one,
            sigma_sq: one,
            nugget: F::from_f64(1e-6).expect("const"),
            trend_fn: TrendFunction::Constant,
            approx_method: FastKrigingMethod::Local,
            max_neighbors: DEFAULT_MAX_NEIGHBORS,
            radius_multiplier: F::from_f64(DEFAULT_RADIUS_MULTIPLIER).expect("const"),
            _phantom: PhantomData,
        }
    }

    /// Sets the training locations, one point per row.
    pub fn points(self, points: Array2<F>) -> Self {
        Self {
            points: Some(points),
            ..self
        }
    }

    /// Sets the observed values, one per training point.
    pub fn values(self, values: Array1<F>) -> Self {
        Self {
            values: Some(values),
            ..self
        }
    }

    /// Chooses the covariance kernel family.
    pub fn covariance_function(self, covfn: CovarianceFunction) -> Self {
        Self {
            cov_fn: covfn,
            ..self
        }
    }

    /// Supplies per-dimension length scales.
    ///
    /// Only the first entry is used, as the isotropic length scale. An empty array leaves the
    /// current length scale unchanged.
    pub fn length_scales(self, lengthscales: Array1<F>) -> Self {
        let length_scale = lengthscales
            .first()
            .copied()
            .unwrap_or(self.length_scale);
        Self {
            length_scale,
            length_scales: Some(lengthscales),
            ..self
        }
    }

    /// Sets the isotropic length scale.
    pub fn length_scale(self, lengthscale: F) -> Self {
        Self {
            length_scale: lengthscale,
            ..self
        }
    }

    /// Sets the marginal variance.
    pub fn sigma_sq(self, sigmasq: F) -> Self {
        Self {
            sigma_sq: sigmasq,
            ..self
        }
    }

    /// Sets the diagonal nugget added to training covariance matrices.
    pub fn nugget(self, nugget: F) -> Self {
        Self { nugget, ..self }
    }

    /// Sets the trend function (stored on the model; not used in prediction).
    pub fn trend_function(self, trendfn: TrendFunction) -> Self {
        Self {
            trend_fn: trendfn,
            ..self
        }
    }

    /// Chooses the approximation strategy.
    pub fn approximation_method(self, method: FastKrigingMethod) -> Self {
        Self {
            approx_method: method,
            ..self
        }
    }

    /// Sets the neighbourhood size used by [`FastKrigingMethod::Local`].
    pub fn max_neighbors(self, maxneighbors: usize) -> Self {
        Self {
            max_neighbors: maxneighbors,
            ..self
        }
    }

    /// Sets the radius multiplier (stored on the model; not used in prediction).
    pub fn radius_multiplier(self, multiplier: F) -> Self {
        Self {
            radius_multiplier: multiplier,
            ..self
        }
    }

    /// Validates the configuration and builds the model.
    ///
    /// # Errors
    ///
    /// Propagates any error from model construction. Examples include missing points or
    /// values, mismatched lengths, an empty point set, or a failed precomputation.
    pub fn build(self) -> InterpolateResult<FastKriging<F>> {
        FastKriging::from_builder(self)
    }
}
impl<F> FastKriging<F>
where
F: Float
+ FromPrimitive
+ ordered_float::FloatCore
+ Debug
+ Display
+ Add<Output = F>
+ Sub<Output = F>
+ Mul<Output = F>
+ Div<Output = F>
+ std::ops::AddAssign
+ std::ops::SubAssign
+ std::ops::MulAssign
+ std::ops::DivAssign
+ std::ops::RemAssign
+ 'static,
{
pub fn builder() -> FastKrigingBuilder<F> {
FastKrigingBuilder::new()
}
fn from_builder(builder: FastKrigingBuilder<F>) -> InterpolateResult<FastKriging<F>> {
let points = builder.points.ok_or(InterpolateError::MissingPoints)?;
let values = builder.values.ok_or(InterpolateError::MissingValues)?;
if points.nrows() != values.len() {
return Err(InterpolateError::DimensionMismatch(
"Number of points must match number of values".to_string(),
));
}
if points.is_empty() {
return Err(InterpolateError::InvalidValue(
"Points array cannot be empty".to_string(),
));
}
let kdtree = KdTree::new(points.clone()).ok();
let state = build_approx_state(
&points,
&values,
builder.cov_fn,
builder.length_scale,
builder.sigma_sq,
builder.nugget,
builder.approx_method,
)?;
Ok(FastKriging {
points,
values,
cov_fn: builder.cov_fn,
length_scale: builder.length_scale,
sigma_sq: builder.sigma_sq,
nugget: builder.nugget,
trend_fn: builder.trend_fn,
approx_method: builder.approx_method,
max_neighbors: builder.max_neighbors,
radius_multiplier: builder.radius_multiplier,
kdtree,
state,
_phantom: PhantomData,
})
}
pub fn n_points(&self) -> usize {
self.points.nrows()
}
pub fn n_dims(&self) -> usize {
self.points.ncols()
}
pub fn approximation_method(&self) -> FastKrigingMethod {
self.approx_method
}
pub fn predict(
&self,
query_points: &ArrayView2<F>,
) -> InterpolateResult<FastPredictionResult<F>> {
if query_points.ncols() != self.points.ncols() {
return Err(InterpolateError::DimensionMismatch(format!(
"Query dimensionality {} does not match training dimensionality {}",
query_points.ncols(),
self.points.ncols()
)));
}
if query_points.nrows() == 0 {
return Ok(FastPredictionResult {
value: Array1::zeros(0),
variance: Array1::zeros(0),
method: self.approx_method,
computation_time_ms: None,
});
}
match self.approx_method {
FastKrigingMethod::Local => self.predict_local(query_points),
FastKrigingMethod::FixedRank(_) => self.predict_nystrom(query_points),
FastKrigingMethod::Tapering(_) => self.predict_tapered(query_points),
FastKrigingMethod::HODLR(_) => self.predict_hodlr(query_points),
}
}
fn predict_local(
&self,
query_points: &ArrayView2<F>,
) -> InterpolateResult<FastPredictionResult<F>> {
let n_query = query_points.nrows();
let mut pred_values = Array1::zeros(n_query);
let mut pred_variances = Array1::zeros(n_query);
let k = self.max_neighbors.min(self.points.nrows());
let global_mean = compute_mean(&self.values);
for qi in 0..n_query {
let query = query_points.slice(scirs2_core::ndarray::s![qi, ..]);
let neighbors = self.find_neighbors_kd(&query, k)?;
let m = neighbors.len();
if m == 0 {
pred_values[qi] = global_mean;
pred_variances[qi] = self.sigma_sq;
continue;
}
if m == 1 {
pred_values[qi] = self.values[neighbors[0].0];
pred_variances[qi] = F::zero();
continue;
}
let local_pts: Array2<F> = extract_rows(&self.points, &neighbors);
let local_vals: Array1<F> = {
let mut v = Array1::zeros(m);
for (j, &(idx, _)) in neighbors.iter().enumerate() {
v[j] = self.values[idx];
}
v
};
let k_local = self.build_cov_matrix(&local_pts);
let k_star = self.build_cross_cov(&query, &local_pts);
let weights = cholesky_solve(&k_local, &local_vals)
.unwrap_or_else(|_| uniform_weights(m, global_mean, &local_vals));
let mut pred = F::zero();
for j in 0..m {
pred = pred + k_star[j] * weights[j];
}
pred_values[qi] = pred;
let alpha = cholesky_solve(&k_local, &k_star).unwrap_or_else(|_| Array1::zeros(m));
let mut reduction = F::zero();
for j in 0..m {
reduction = reduction + k_star[j] * alpha[j];
}
let variance_raw = self.sigma_sq - reduction;
let variance = if variance_raw < F::zero() {
F::zero()
} else {
variance_raw
};
pred_variances[qi] = variance;
}
Ok(FastPredictionResult {
value: pred_values,
variance: pred_variances,
method: self.approx_method,
computation_time_ms: None,
})
}
fn predict_nystrom(
&self,
query_points: &ArrayView2<F>,
) -> InterpolateResult<FastPredictionResult<F>> {
let nys = match &self.state {
ApproxState::Nystrom(ns) => ns,
_ => {
return Err(InterpolateError::InvalidState(
"Nyström state not initialised for FixedRank method".to_string(),
))
}
};
let n_query = query_points.nrows();
let mut pred_values = Array1::zeros(n_query);
let mut pred_variances = Array1::zeros(n_query);
for qi in 0..n_query {
let query = query_points.slice(scirs2_core::ndarray::s![qi, ..]);
let k_qm = self.build_cross_cov(&query, &nys.inducing_points);
let mut pred = F::zero();
for j in 0..nys.rank {
pred = pred + k_qm[j] * nys.kmi_kmy[j];
}
pred_values[qi] = pred;
let alpha = back_sub_transpose(
&nys.l_mm,
&forward_sub(&nys.l_mm, &k_qm).unwrap_or_else(|_| Array1::zeros(nys.rank)),
)
.unwrap_or_else(|_| Array1::zeros(nys.rank));
let mut reduction = F::zero();
for j in 0..nys.rank {
reduction = reduction + k_qm[j] * alpha[j];
}
let var_nys = self.sigma_sq - reduction;
pred_variances[qi] = if var_nys < F::zero() {
F::zero()
} else {
var_nys
};
}
Ok(FastPredictionResult {
value: pred_values,
variance: pred_variances,
method: self.approx_method,
computation_time_ms: None,
})
}
fn predict_tapered(
&self,
query_points: &ArrayView2<F>,
) -> InterpolateResult<FastPredictionResult<F>> {
let taper_state = match &self.state {
ApproxState::Taper(ts) => ts,
_ => {
return Err(InterpolateError::InvalidState(
"Taper state not initialised for Tapering method".to_string(),
))
}
};
let n_query = query_points.nrows();
let mut pred_values = Array1::zeros(n_query);
let mut pred_variances = Array1::zeros(n_query);
let range = taper_state.taper_range;
let global_mean = compute_mean(&self.values);
for qi in 0..n_query {
let query = query_points.slice(scirs2_core::ndarray::s![qi, ..]);
let n_train = self.points.nrows();
let mut active: Vec<usize> = Vec::new();
let mut dists_q: Vec<F> = Vec::new();
for j in 0..n_train {
let pt = self.points.slice(scirs2_core::ndarray::s![j, ..]);
let dist = euclidean_distance(&query, &pt) / self.length_scale;
if dist < range / self.length_scale {
active.push(j);
dists_q.push(dist);
}
}
if active.is_empty() {
pred_values[qi] = global_mean;
pred_variances[qi] = self.sigma_sq;
continue;
}
let m = active.len();
let active_pts: Array2<F> = {
let mut ap = Array2::zeros((m, self.points.ncols()));
for (row, &idx) in active.iter().enumerate() {
ap.slice_mut(scirs2_core::ndarray::s![row, ..])
.assign(&self.points.slice(scirs2_core::ndarray::s![idx, ..]));
}
ap
};
let active_vals: Array1<F> = {
let mut av = Array1::zeros(m);
for (j, &idx) in active.iter().enumerate() {
av[j] = self.values[idx];
}
av
};
let mut k_local = Array2::<F>::zeros((m, m));
for j in 0..m {
for kk in 0..m {
let pt_j = active_pts.slice(scirs2_core::ndarray::s![j, ..]);
let pt_k = active_pts.slice(scirs2_core::ndarray::s![kk, ..]);
let dist = euclidean_distance(&pt_j, &pt_k) / self.length_scale;
let cov = eval_covariance(dist, self.sigma_sq, self.cov_fn);
let tap = wendland_c2(dist * self.length_scale, range);
if j == kk {
k_local[[j, kk]] = cov * tap + self.nugget;
} else {
k_local[[j, kk]] = cov * tap;
}
}
}
let mut k_star = Array1::zeros(m);
for (j, &dist_scaled) in dists_q.iter().enumerate() {
let dist_abs = dist_scaled * self.length_scale;
let cov = eval_covariance(dist_scaled, self.sigma_sq, self.cov_fn);
let tap = wendland_c2(dist_abs, range);
k_star[j] = cov * tap;
}
let weights = cholesky_solve(&k_local, &active_vals)
.unwrap_or_else(|_| uniform_weights(m, global_mean, &active_vals));
let mut pred = F::zero();
for j in 0..m {
pred = pred + k_star[j] * weights[j];
}
pred_values[qi] = pred;
let alpha = cholesky_solve(&k_local, &k_star).unwrap_or_else(|_| Array1::zeros(m));
let mut reduction = F::zero();
for j in 0..m {
reduction = reduction + k_star[j] * alpha[j];
}
let var_tap = self.sigma_sq - reduction;
pred_variances[qi] = if var_tap < F::zero() {
F::zero()
} else {
var_tap
};
}
let _ = &taper_state.sparse;
Ok(FastPredictionResult {
value: pred_values,
variance: pred_variances,
method: self.approx_method,
computation_time_ms: None,
})
}
fn predict_hodlr(
&self,
query_points: &ArrayView2<F>,
) -> InterpolateResult<FastPredictionResult<F>> {
let leaf_size = match self.approx_method {
FastKrigingMethod::HODLR(ls) => ls.max(2),
_ => 32,
};
let n_train = self.points.nrows();
let n_query = query_points.nrows();
let mut pred_values = Array1::zeros(n_query);
let mut pred_variances = Array1::zeros(n_query);
let global_mean = compute_mean(&self.values);
let n_blocks = (n_train + leaf_size - 1) / leaf_size;
for qi in 0..n_query {
let query = query_points.slice(scirs2_core::ndarray::s![qi, ..]);
let mut total_weight = F::zero();
let mut weighted_pred = F::zero();
let mut weighted_var = F::zero();
for b in 0..n_blocks {
let start = b * leaf_size;
let end = n_train.min(start + leaf_size);
if start >= end {
continue;
}
let d = self.points.ncols();
let mut centroid = vec![F::zero(); d];
for j in start..end {
for dd in 0..d {
centroid[dd] = centroid[dd] + self.points[[j, dd]];
}
}
let block_len = F::from_usize(end - start).expect("const");
for dd in 0..d {
centroid[dd] = centroid[dd] / block_len;
}
let mut dist_sq = F::zero();
for dd in 0..d {
let diff = query[dd] - centroid[dd];
dist_sq = dist_sq + diff * diff;
}
let dist = dist_sq.sqrt();
let weight = F::one() / (F::one() + dist);
if weight < F::from_f64(1e-8).expect("const") {
continue;
}
let block_pts_slice = self.points.slice(scirs2_core::ndarray::s![start..end, ..]);
let block_pts = block_pts_slice.to_owned();
let block_vals_slice = self.values.slice(scirs2_core::ndarray::s![start..end]);
let block_vals = block_vals_slice.to_owned();
let (local_pred, local_var) =
self.block_local_predict(&query, &block_pts, &block_vals, global_mean)?;
weighted_pred = weighted_pred + weight * local_pred;
weighted_var = weighted_var + weight * weight * local_var;
total_weight = total_weight + weight;
}
if total_weight > F::zero() {
pred_values[qi] = weighted_pred / total_weight;
let raw_var = weighted_var / (total_weight * total_weight);
pred_variances[qi] = if raw_var < F::zero() {
F::zero()
} else {
raw_var
};
} else {
pred_values[qi] = global_mean;
pred_variances[qi] = self.sigma_sq;
}
}
Ok(FastPredictionResult {
value: pred_values,
variance: pred_variances,
method: self.approx_method,
computation_time_ms: None,
})
}
fn block_local_predict(
&self,
query: &ArrayView1<F>,
block_pts: &Array2<F>,
block_vals: &Array1<F>,
global_mean: F,
) -> InterpolateResult<(F, F)> {
let m = block_pts.nrows();
if m == 0 {
return Ok((global_mean, self.sigma_sq));
}
if m == 1 {
return Ok((block_vals[0], F::zero()));
}
let k_local = self.build_cov_matrix(block_pts);
let k_star = self.build_cross_cov(query, block_pts);
let weights = cholesky_solve(&k_local, block_vals)
.unwrap_or_else(|_| uniform_weights(m, global_mean, block_vals));
let mut pred = F::zero();
for j in 0..m {
pred = pred + k_star[j] * weights[j];
}
let alpha = cholesky_solve(&k_local, &k_star).unwrap_or_else(|_| Array1::zeros(m));
let mut reduction = F::zero();
for j in 0..m {
reduction = reduction + k_star[j] * alpha[j];
}
let var_raw = self.sigma_sq - reduction;
let var = if var_raw < F::zero() {
F::zero()
} else {
var_raw
};
Ok((pred, var))
}
fn find_neighbors_kd(
&self,
query: &ArrayView1<F>,
k: usize,
) -> InterpolateResult<Vec<(usize, F)>> {
let query_slice = query.as_slice().ok_or_else(|| {
InterpolateError::InvalidValue("Query must be contiguous".to_string())
})?;
match &self.kdtree {
Some(tree) => tree.k_nearest_neighbors(query_slice, k),
None => {
let n = self.points.nrows();
let mut dists: Vec<(usize, F)> = (0..n)
.map(|i| {
let pt = self.points.slice(scirs2_core::ndarray::s![i, ..]);
let d = euclidean_distance(query, &pt);
(i, d)
})
.collect();
dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
dists.truncate(k);
Ok(dists)
}
}
}
fn build_cov_matrix(&self, pts: &Array2<F>) -> Array2<F> {
let m = pts.nrows();
let mut mat = Array2::zeros((m, m));
for i in 0..m {
for j in 0..m {
if i == j {
mat[[i, j]] = self.sigma_sq + self.nugget;
} else {
let pi = pts.slice(scirs2_core::ndarray::s![i, ..]);
let pj = pts.slice(scirs2_core::ndarray::s![j, ..]);
let r = euclidean_distance(&pi, &pj) / self.length_scale;
mat[[i, j]] = eval_covariance(r, self.sigma_sq, self.cov_fn);
}
}
}
mat
}
fn build_cross_cov(&self, query: &ArrayView1<F>, pts: &Array2<F>) -> Array1<F> {
let m = pts.nrows();
let mut kv = Array1::zeros(m);
for j in 0..m {
let pj = pts.slice(scirs2_core::ndarray::s![j, ..]);
let r = euclidean_distance(query, &pj) / self.length_scale;
kv[j] = eval_covariance(r, self.sigma_sq, self.cov_fn);
}
kv
}
}
/// Precomputes method-specific state for a new model.
///
/// `FixedRank` builds inducing-point state and `Tapering` builds the tapered sparse
/// covariance. `Local` and `HODLR` need nothing.
///
/// # Errors
///
/// * `InvalidValue` if a taper range is NaN, not strictly positive, or not representable
///   as `F`.
/// * Any error from building the fixed-rank or tapered state.
fn build_approx_state<F>(
    points: &Array2<F>,
    values: &Array1<F>,
    cov_fn: CovarianceFunction,
    length_scale: F,
    sigma_sq: F,
    nugget: F,
    method: FastKrigingMethod,
) -> InterpolateResult<ApproxState<F>>
where
    F: Float
        + FromPrimitive
        + ordered_float::FloatCore
        + Debug
        + Display
        + Add<Output = F>
        + Sub<Output = F>
        + Mul<Output = F>
        + Div<Output = F>
        + std::ops::AddAssign
        + std::ops::SubAssign
        + std::ops::MulAssign
        + std::ops::DivAssign
        + std::ops::RemAssign
        + 'static,
{
    match method {
        FastKrigingMethod::FixedRank(rank) => {
            build_nystrom_state(points, values, cov_fn, length_scale, sigma_sq, nugget, rank)
                .map(ApproxState::Nystrom)
        }
        FastKrigingMethod::Tapering(range_f64) => {
            // With a non-positive or NaN range no training point is ever inside the taper
            // support, so every prediction would silently collapse to the global mean.
            if range_f64.is_nan() || range_f64 <= 0.0 {
                return Err(InterpolateError::InvalidValue(format!(
                    "Taper range must be strictly positive, got {range_f64}"
                )));
            }
            let range = F::from_f64(range_f64)
                .ok_or_else(|| InterpolateError::InvalidValue("Invalid taper range".to_string()))?;
            let sparse =
                build_tapered_sparse(points, cov_fn, length_scale, sigma_sq, nugget, range)?;
            Ok(ApproxState::Taper(TaperState {
                taper_range: range,
                sparse,
            }))
        }
        FastKrigingMethod::Local | FastKrigingMethod::HODLR(_) => Ok(ApproxState::None),
    }
}
fn build_nystrom_state<F>(
points: &Array2<F>,
values: &Array1<F>,
cov_fn: CovarianceFunction,
length_scale: F,
sigma_sq: F,
nugget: F,
rank: usize,
) -> InterpolateResult<NystromState<F>>
where
F: Float
+ FromPrimitive
+ ordered_float::FloatCore
+ Debug
+ Display
+ std::ops::AddAssign
+ 'static,
{
let n = points.nrows();
let m = rank.min(n);
let step = if m > 1 { n / m } else { 1 };
let inducing_indices: Vec<usize> = (0..m).map(|i| (i * step).min(n - 1)).collect();
let d = points.ncols();
let mut ind_pts = Array2::zeros((m, d));
for (row, &idx) in inducing_indices.iter().enumerate() {
ind_pts
.slice_mut(scirs2_core::ndarray::s![row, ..])
.assign(&points.slice(scirs2_core::ndarray::s![idx, ..]));
}
let mut k_mm = Array2::zeros((m, m));
for i in 0..m {
for j in 0..m {
if i == j {
k_mm[[i, j]] = sigma_sq + nugget;
} else {
let pi = ind_pts.slice(scirs2_core::ndarray::s![i, ..]);
let pj = ind_pts.slice(scirs2_core::ndarray::s![j, ..]);
let r = euclidean_distance(&pi, &pj) / length_scale;
k_mm[[i, j]] = eval_covariance(r, sigma_sq, cov_fn);
}
}
}
let l_mm = cholesky_lower(&k_mm)?;
let mut k_mn_y = Array1::zeros(m);
for i in 0..m {
let pi = ind_pts.slice(scirs2_core::ndarray::s![i, ..]);
let mut dot = F::zero();
for j in 0..n {
let pj = points.slice(scirs2_core::ndarray::s![j, ..]);
let r = euclidean_distance(&pi, &pj) / length_scale;
dot = dot + eval_covariance(r, sigma_sq, cov_fn) * values[j];
}
k_mn_y[i] = dot;
}
let y_fwd = forward_sub(&l_mm, &k_mn_y)?;
let kmi_kmy = back_sub_transpose(&l_mm, &y_fwd)?;
Ok(NystromState {
inducing_points: ind_pts,
l_mm,
kmi_kmy,
rank: m,
})
}
/// Builds the tapered training covariance in coordinate form.
///
/// Each entry is the kernel value times the Wendland C2 taper, with the nugget added on the
/// diagonal. Off-diagonal pairs whose taper is exactly zero are omitted. For each kept pair
/// `(i, j)` with `j <= i`, the entry is pushed as `(i, j)` and then, if off-diagonal, as
/// `(j, i)`.
fn build_tapered_sparse<F>(
    points: &Array2<F>,
    cov_fn: CovarianceFunction,
    length_scale: F,
    sigma_sq: F,
    nugget: F,
    taper_range: F,
) -> InterpolateResult<SparseComponents<F>>
where
    F: Float + FromPrimitive + ordered_float::FloatCore + std::ops::AddAssign + 'static,
{
    let mut indices: Vec<(usize, usize)> = Vec::new();
    let mut vals: Vec<F> = Vec::new();
    for (i, pi) in points.outer_iter().enumerate() {
        for (j, pj) in points.outer_iter().take(i + 1).enumerate() {
            let dist = euclidean_distance(&pi, &pj);
            let tap = wendland_c2(dist, taper_range);
            let on_diagonal = i == j;
            // Pairs outside the taper support are structurally zero; the diagonal is always kept.
            if !on_diagonal && tap == F::zero() {
                continue;
            }
            let tapered = eval_covariance(dist / length_scale, sigma_sq, cov_fn) * tap;
            let entry = if on_diagonal { tapered + nugget } else { tapered };
            indices.push((i, j));
            vals.push(entry);
            if !on_diagonal {
                indices.push((j, i));
                vals.push(entry);
            }
        }
    }
    Ok((indices, vals))
}
pub fn make_local_kriging<F>(
points: &ArrayView2<F>,
values: &ArrayView1<F>,
cov_fn: CovarianceFunction,
scale: F,
max_neighbors: usize,
) -> InterpolateResult<FastKriging<F>>
where
F: Float
+ FromPrimitive
+ ordered_float::FloatCore
+ Debug
+ Display
+ Add<Output = F>
+ Sub<Output = F>
+ Mul<Output = F>
+ Div<Output = F>
+ std::ops::AddAssign
+ std::ops::SubAssign
+ std::ops::MulAssign
+ std::ops::DivAssign
+ std::ops::RemAssign
+ 'static,
{
FastKrigingBuilder::new()
.points(points.to_owned())
.values(values.to_owned())
.covariance_function(cov_fn)
.length_scale(scale)
.approximation_method(FastKrigingMethod::Local)
.max_neighbors(max_neighbors)
.build()
}
pub fn make_fixed_rank_kriging<F>(
points: &ArrayView2<F>,
values: &ArrayView1<F>,
rank: usize,
cov_fn: CovarianceFunction,
scale: F,
) -> InterpolateResult<FastKriging<F>>
where
F: Float
+ FromPrimitive
+ ordered_float::FloatCore
+ Debug
+ Display
+ Add<Output = F>
+ Sub<Output = F>
+ Mul<Output = F>
+ Div<Output = F>
+ std::ops::AddAssign
+ std::ops::SubAssign
+ std::ops::MulAssign
+ std::ops::DivAssign
+ std::ops::RemAssign
+ 'static,
{
FastKrigingBuilder::new()
.points(points.to_owned())
.values(values.to_owned())
.covariance_function(cov_fn)
.length_scale(scale)
.approximation_method(FastKrigingMethod::FixedRank(rank))
.build()
}
pub fn make_tapered_kriging<F>(
points: &ArrayView2<F>,
values: &ArrayView1<F>,
taper_range: F,
cov_fn: CovarianceFunction,
scale: F,
) -> InterpolateResult<FastKriging<F>>
where
F: Float
+ FromPrimitive
+ ordered_float::FloatCore
+ Debug
+ Display
+ Add<Output = F>
+ Sub<Output = F>
+ Mul<Output = F>
+ Div<Output = F>
+ std::ops::AddAssign
+ std::ops::SubAssign
+ std::ops::MulAssign
+ std::ops::DivAssign
+ std::ops::RemAssign
+ 'static,
{
let range_f64 = taper_range.to_f64().ok_or_else(|| {
InterpolateError::InvalidValue("Cannot convert taper_range to f64".to_string())
})?;
FastKrigingBuilder::new()
.points(points.to_owned())
.values(values.to_owned())
.covariance_function(cov_fn)
.length_scale(scale)
.approximation_method(FastKrigingMethod::Tapering(range_f64))
.build()
}
pub fn make_hodlr_kriging<F>(
points: &ArrayView2<F>,
values: &ArrayView1<F>,
leaf_size: usize,
cov_fn: CovarianceFunction,
scale: F,
) -> InterpolateResult<FastKriging<F>>
where
F: Float
+ FromPrimitive
+ ordered_float::FloatCore
+ Debug
+ Display
+ Add<Output = F>
+ Sub<Output = F>
+ Mul<Output = F>
+ Div<Output = F>
+ std::ops::AddAssign
+ std::ops::SubAssign
+ std::ops::MulAssign
+ std::ops::DivAssign
+ std::ops::RemAssign
+ 'static,
{
FastKrigingBuilder::new()
.points(points.to_owned())
.values(values.to_owned())
.covariance_function(cov_fn)
.length_scale(scale)
.approximation_method(FastKrigingMethod::HODLR(leaf_size))
.build()
}
/// Heuristically picks an approximation method from the training-set size.
///
/// | points            | method              |
/// |-------------------|---------------------|
/// | `< 500`           | `Local`             |
/// | `500 ..< 5_000`   | `FixedRank(50)`     |
/// | `5_000 ..< 50_000`| `Tapering(3.0)`     |
/// | `>= 50_000`       | `HODLR(64)`         |
///
/// NOTE(review): the taper range `3.0` is in absolute distance units, so it may not suit
/// data on a very different spatial scale.
pub fn select_approximation_method(n_points: usize) -> FastKrigingMethod {
    match n_points {
        0..=499 => FastKrigingMethod::Local,
        500..=4_999 => FastKrigingMethod::FixedRank(50),
        5_000..=49_999 => FastKrigingMethod::Tapering(3.0),
        _ => FastKrigingMethod::HODLR(64),
    }
}
/// Arithmetic mean of `values`, or zero for an empty array.
fn compute_mean<F: Float + FromPrimitive>(values: &Array1<F>) -> F {
    match values.len() {
        0 => F::zero(),
        len => {
            let total = values.iter().fold(F::zero(), |acc, &v| acc + v);
            total / F::from_usize(len).expect("const")
        }
    }
}
/// Copies the rows of `pts` referenced by `neighbors` (`(index, distance)` pairs) into a new
/// matrix, preserving the neighbour order.
fn extract_rows<F: Float>(pts: &Array2<F>, neighbors: &[(usize, F)]) -> Array2<F> {
    let mut selected = Array2::zeros((neighbors.len(), pts.ncols()));
    for (mut dst, &(src_idx, _)) in selected.outer_iter_mut().zip(neighbors) {
        dst.assign(&pts.row(src_idx));
    }
    selected
}
/// Returns `m` equal weights of `1/m`, used as a fallback when a kriging system cannot be
/// solved.
///
/// `_global_mean` and `_vals` are accepted for call-site compatibility but do not affect the
/// result. The previous version computed the sum of `vals` and then ran two identical
/// branches on it.
fn uniform_weights<F: Float + FromPrimitive>(
    m: usize,
    _global_mean: F,
    _vals: &Array1<F>,
) -> Array1<F> {
    let n = F::from_usize(m).expect("const");
    Array1::from_elem(m, F::one() / n)
}
// Unit tests live in a separate file; `#[path]` points the `tests` module at it.
#[cfg(test)]
#[path = "fast_kriging_reexports_tests.rs"]
mod tests;